// Copyright 2018, Intel Corp.

#include "tile/targets/cpu/compiler.h"

#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/DynamicLibrary.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/ToolOutputFile.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>

#include <algorithm>
#include <deque>
#include <memory>
#include <utility>
#include <vector>

#include <half.hpp>

#include "base/util/lookup.h"
#include "tile/stripe/stripe.h"
#include "tile/targets/cpu/executable.h"
#include "tile/targets/cpu/link_names.h"

namespace vertexai {
namespace tile {
namespace targets {
namespace cpu {

#define VECTOR_MEM_ALIGNMENT 16

// Exception type thrown for compilation failures within this translation
// unit; inherits all constructors from std::runtime_error.
class Error : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;
};

// Public constructor: binds this compiler to an LLVM context and performs
// the process-wide LLVM/MCJIT initialization exactly once.
Compiler::Compiler(llvm::LLVMContext* context, const Config& config)
    : context_(*context), builder_{context_}, config_{config}, arenaSize_(0) {
  // One-time native-target setup, shared by every Compiler instance.
  static std::once_flag llvm_init_flag;
  std::call_once(llvm_init_flag, [] {
    LLVMInitializeNativeTarget();
    LLVMLinkInMCJIT();
    LLVMInitializeNativeAsmPrinter();
    LLVMInitializeNativeAsmParser();
  });
}

// Compile a complete Stripe program into an LLVM module wrapped in a
// ProgramModule. Emits the arena global, one function per block, and an
// external invoker stub, then runs -O3 module optimization passes.
// Throws std::runtime_error if the native target cannot be found or the
// generated module fails verification.
ProgramModule Compiler::CompileProgram(const stripe::Block& program) {
  IVLOG(4, program);
  // Compile each block in this program into a function within an LLVM module.
  ProgramModule ret;
  ret.module = std::make_unique<llvm::Module>("stripe", context_);
  module_ = ret.module.get();

  // Configure the module for the host we are running on.
  auto targetTriple = llvm::sys::getProcessTriple();
  std::string errorMessage;
  auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
  if (!target) {
    // lookupTarget returns null on failure and describes the reason in
    // errorMessage; previously this pointer was dereferenced unchecked.
    throw std::runtime_error("Unable to find target for triple \"" + targetTriple + "\": " + errorMessage);
  }
  std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(targetTriple, "generic", "", {}, {}));
  module_->setDataLayout(machine->createDataLayout());
  module_->setTargetTriple(targetTriple);

  GenerateArena(program);
  llvm::Function* main = CompileBlock(program);
  ret.externals = external_funcptrs_;
  // Generate a stub function we can invoke from the outside, passing buffers
  // as an array of generic pointers.
  GenerateInvoker(program, main);
  // Improve the simple-minded IR we've just generated by running module-level
  // optimization passes; among many other things, this will streamline our
  // loops to eliminate most branches and inline most block function calls.
  llvm::PassManagerBuilder pmb;
  pmb.OptLevel = 3;
  pmb.SizeLevel = 0;
  pmb.SLPVectorize = true;
  pmb.LoopVectorize = true;
  pmb.MergeFunctions = true;
  llvm::legacy::PassManager modopt;
  pmb.populateModulePassManager(modopt);
  if (config_.print_llvm_ir_simple) {
    llvm::errs() << "LLVM IR, unoptimized: ================\n";
    module_->print(llvm::errs(), nullptr);
  }
  // Verify before optimizing so malformed IR is reported, not miscompiled.
  if (llvm::verifyModule(*module_, &llvm::errs())) {
    throw std::runtime_error("LLVM module verification failed");
  }
  modopt.run(*module_);
  if (config_.print_llvm_ir_optimized) {
    llvm::errs() << "LLVM IR, after optimization: ================\n";
    module_->print(llvm::errs(), nullptr);
  }
  if (config_.print_assembly) {
    llvm::errs() << "Assembly code: ================\n";
    PrintOutputAssembly(machine.get());
  }
  // Wrap the finished module and the parameter names into a ProgramModule.
  for (auto& ref : program.refs) {
    if (ref.has_tag("user")) {
      ret.parameters.push_back(ref.into());
    }
  }
  module_ = nullptr;
  assert(ret.module);
  return ret;
}

// Private constructor used when compiling nested blocks.
Compiler::Compiler(llvm::LLVMContext* context, llvm::Module* module, const Config& config)
    : context_(*context), builder_{context_}, module_(module), config_{config}, arenaSize_(0) {
  // This private constructor sets up a nested instance which will
  // process a nested block, generating output into the same module as its
  // containing compiler instance.
  // Unlike the public constructor, it performs no LLVM global initialization;
  // the outer instance has already done so.
}

void Compiler::GenerateInvoker(const stripe::Block& program, llvm::Function* main) {
  // Generate a wrapper function for this program, so that we can call it
  // generically through a single C function pointer type, no matter how many
  // buffer parameters it expects. From C, we will prepare a vector of void*,
  // containing the parameter buffer data pointers; the wrapper will extract
  // each data pointer, then pass each one as a parameter when it calls the
  // program's top-level block function.
  assert(!program.has_tag("cpu_thread"));
  // LLVM doesn't have the notion of a void pointer, so we'll pretend all of
  // these buffers are arrays of int8, then bitcast later.
  llvm::Type* arrayptr = builder_.getInt8PtrTy()->getPointerTo();
  llvm::Type* voidtype = builder_.getVoidTy();
  // The invoker's C-visible signature: void invoke(int8_t** buffers).
  auto invoker_type = llvm::FunctionType::get(voidtype, {arrayptr}, false);
  auto linkage = llvm::Function::ExternalLinkage;
  auto invoker = llvm::Function::Create(invoker_type, linkage, invoker_name_, module_);
  auto block = llvm::BasicBlock::Create(context_, "block", invoker);
  builder_.SetInsertPoint(block);
  // We'll look up the kernel by name and implicitly bitcast it so we can call
  // it using our int32-pointers in place of whatever it actually expects;
  // LLVM will tolerate this mismatch when we use getOrInsertFunction.
  auto ai = invoker->arg_begin();
  llvm::Value* argvec = &(*ai);
  // The body of the invoker will compute the element pointer for each
  // argument value in order, then load the value.
  std::vector<llvm::Value*> args;
  std::vector<llvm::Value*> allocs;  // heap buffers to Free() before returning
  {
    if (arenaSize_) {
      IVLOG(1, "Arena size: " << arenaSize_);
      // allocate the arena on the heap. This way static initialization order is never an issue.
      auto buffer = Malloc(arenaSize_);
      // Publish the arena base pointer via the global emitted by GenerateArena.
      auto arena_gval = module_->getNamedGlobal(arena_name_);
      auto arenatype = llvm::ArrayType::get(builder_.getInt8Ty(), 1)->getPointerTo();
      builder_.CreateStore(builder_.CreateBitCast(buffer, arenatype), arena_gval);
      allocs.push_back(buffer);
    }
    // Only #user refinements consume slots from the caller's buffer array;
    // #tmp refinements are allocated (and later freed) here.
    unsigned i = 0;
    for (auto& ref : program.refs) {
      if (ref.has_tag("user")) {
        // The refinement must be provided as a parameter by the user
        llvm::Value* index = builder_.getInt32(i++);
        std::vector<llvm::Value*> idxList{index};
        llvm::Value* elptr = builder_.CreateGEP(argvec, idxList);
        llvm::Value* elval = builder_.CreateLoad(elptr, ref.into());
        llvm::Type* eltype = CType(ref.interior_shape.type)->getPointerTo();
        args.push_back(builder_.CreateBitCast(elval, eltype));
      } else if (ref.has_tag("tmp")) {
        // Allocate a temporary buffer for this refinement
        auto buffer = Malloc(ref.interior_shape.byte_size());
        allocs.push_back(buffer);
        llvm::Type* buftype = CType(ref.interior_shape.type)->getPointerTo();
        args.push_back(builder_.CreateBitCast(buffer, buftype));
      } else {
        throw std::runtime_error("Top-level refinement missing #user or #tmp");
      }
      args.back()->setName(ref.into());
    }
  }
  // After passing in buffer pointers, we also provide an initial value for
  // each index; since this is the outermost block, all indexes begin at zero.
  for (unsigned i = 0; i < program.idxs.size(); ++i) {
    args.push_back(IndexConst(0));
  }
  // Having built the argument list, we'll call the actual kernel using the
  // parameter signature it expects.
  builder_.CreateCall(main, args, "");
  // Free any temporary buffers we may have allocated.
  for (auto ptr : allocs) {
    Free(ptr);
  }
  builder_.CreateRetVoid();
}

// Compute the number of arena bytes needed to hold every "placed"
// refinement in this block and all blocks nested inside it.
uint64_t Compiler::MeasureArena(const stripe::Block& block) {
  uint64_t size = 0;
  for (const auto& ref : block.refs) {
    // Only refinements the placer has assigned an offset contribute.
    if (!ref.has_tag("placed")) {
      continue;
    }
    // A refinement lives in the arena when it is neither an input nor an
    // output and does not alias an enclosing refinement via "from".
    bool arena_local = (ref.dir == stripe::RefDir::None) && ref.from.empty();
    if (arena_local) {
      uint64_t end = ref.offset + ref.interior_shape.byte_size();
      if (end > size) {
        size = end;
      }
    }
  }
  // Nested blocks may place refinements beyond this block's extent.
  for (const auto& stmt : block.stmts) {
    if (auto inner = stripe::Block::Downcast(stmt)) {
      size = std::max(size, MeasureArena(*inner));
    }
  }
  return size;
}

// Record how much scratch space the program needs, then emit a module-level
// global pointer (initialized to null) that will hold the arena base address;
// GenerateInvoker fills it in at runtime.
void Compiler::GenerateArena(const stripe::Block& block) {
  arenaSize_ = MeasureArena(block);
  llvm::Type* ptrType = llvm::ArrayType::get(builder_.getInt8Ty(), 1)->getPointerTo();
  module_->getOrInsertGlobal(arena_name_, ptrType);
  auto arena = module_->getNamedGlobal(arena_name_);
  arena->setInitializer(llvm::Constant::getNullValue(ptrType));
}

// Emit a function implementing this block as a single libxsmm GEMM: dispatch
// the matching libxsmm kernel at runtime, then invoke it through the
// XSMMRTCaller helper with the block's A/B/C buffers.
llvm::Function* Compiler::CompileXSMMBlock(const stripe::Block& block, const XSMMDispatch xsmmDispatch,
                                           const XSMMCallData& xsmmCallData) {
  // Validate incoming params.
  if (xsmmCallData.in0 == nullptr || xsmmCallData.in1 == nullptr || xsmmCallData.out0 == nullptr ||
      xsmmCallData.lda_a_value == 0 || xsmmCallData.lda_b_value == 0 || xsmmCallData.lda_c_value == 0) {
    throw std::runtime_error("Invalid xsmmCallData state.");
  }

  // Generate a function that implements the body for this block of statements.
  // Refinements (their buffers) and initial indexes
  // will be passed as parameters (to the function).
  for (const auto& ref : block.refs) {
    buffers_[ref.into()] = Buffer{&ref};
  }
  for (const auto& idx : block.idxs) {
    indexes_[idx.name] = Index{&idx};
  }

  // Create the LLVM function which will implement the Stripe block
  auto linkage = llvm::Function::ExternalLinkage;
  auto name = block.name;
  auto func_type = BlockType(block);
  auto function = llvm::Function::Create(func_type, linkage, name, module_);

  // Create a basic block; configure the builder to start there
  auto bb = llvm::BasicBlock::Create(context_, "entry", function);
  builder_.SetInsertPoint(bb);

  // Associate parameter values with buffers and indexes
  for (auto ai = function->arg_begin(); ai != function->arg_end(); ++ai) {
    unsigned idx = ai->getArgNo();
    if (idx < block.refs.size()) {
      // Leading arguments are refinement base pointers, in refs order.
      auto it = block.refs.begin();
      std::advance(it, idx);
      std::string param_name = it->into();
      ai->setName(param_name);
      assert(nullptr == buffers_[param_name].base);
      buffers_[param_name].base = &(*ai);
    } else {
      // Trailing arguments are index initial values, in idxs order.
      idx -= block.refs.size();
      std::string param_name = block.idxs[idx].name;
      ai->setName(param_name);
      assert(nullptr == indexes_[param_name].init);
      indexes_[param_name].init = &(*ai);
    }
  }
  // Stack slots for the leading dimensions, which libxsmm takes by pointer.
  auto i32t = builder_.getInt32Ty();
  llvm::Value* lda = builder_.CreateAlloca(i32t);
  llvm::Value* ldb = builder_.CreateAlloca(i32t);
  llvm::Value* ldc = builder_.CreateAlloca(i32t);

  // Choose the libxsmm dispatch entry point and operand pointer types that
  // match the block's element type (see GetXSMMDispatch).
  llvm::Value* alpha = nullptr;
  llvm::Value* beta = nullptr;
  llvm::Value* one = nullptr;
  llvm::Type* alphaPtrType = nullptr;
  llvm::Type* betaPtrType = nullptr;
  llvm::Type* aPtrType = nullptr;
  llvm::Type* bPtrType = nullptr;
  llvm::Type* cPtrType = nullptr;
  std::string functionName("Invalid");
  switch (xsmmDispatch) {
    case XSMMDispatch::SMM:
      // Single-precision GEMM.
      one = llvm::ConstantFP::get(builder_.getFloatTy(), 1.0);
      alpha = builder_.CreateAlloca(builder_.getFloatTy());
      beta = builder_.CreateAlloca(builder_.getFloatTy());
      alphaPtrType = betaPtrType = aPtrType = bPtrType = cPtrType = llvm::Type::getFloatPtrTy(context_);
      functionName = "libxsmm_smmdispatch";
      break;

    case XSMMDispatch::DMM:
      // Double-precision GEMM.
      one = llvm::ConstantFP::get(builder_.getDoubleTy(), 1.0L);
      alpha = builder_.CreateAlloca(builder_.getDoubleTy());
      beta = builder_.CreateAlloca(builder_.getDoubleTy());
      alphaPtrType = betaPtrType = aPtrType = bPtrType = cPtrType = llvm::Type::getDoublePtrTy(context_);
      functionName = "libxsmm_dmmdispatch";
      break;

    case XSMMDispatch::WIMM:
      // Integer GEMM: int8/uint8 inputs accumulating into int32 output.
      one = llvm::ConstantInt::get(builder_.getInt8Ty(), 1);
      alpha = builder_.CreateAlloca(builder_.getInt8Ty());
      beta = builder_.CreateAlloca(builder_.getInt8Ty());
      alphaPtrType = betaPtrType = aPtrType = bPtrType = llvm::Type::getInt8PtrTy(context_);
      cPtrType = llvm::Type::getInt32PtrTy(context_);
      functionName = "libxsmm_wimmdispatch";
      break;

      // TODO: Handle the bfloat16 dispatch function when support for bfloat16 is added to Stripe.

    default:
      throw std::runtime_error("Unsupported DataType for XSMM.");
  }

  // alpha = beta = 1: plain C += A*B accumulation.
  builder_.CreateStore(one, alpha);
  builder_.CreateStore(one, beta);

  builder_.CreateStore(llvm::ConstantInt::get(i32t, xsmmCallData.lda_a_value), lda);
  builder_.CreateStore(llvm::ConstantInt::get(i32t, xsmmCallData.lda_b_value), ldb);
  builder_.CreateStore(llvm::ConstantInt::get(i32t, xsmmCallData.lda_c_value), ldc);
  llvm::Value* nptr = llvm::ConstantPointerNull::get(llvm::Type::getInt32PtrTy(context_));
  llvm::Value* dispatch = XSMMDispatchFunction(alphaPtrType, betaPtrType, aPtrType, bPtrType, cPtrType, functionName);

  // Call the dispatch function with (m, n, k, lda, ldb, ldc, alpha, beta,
  // flags, prefetch) — the last two are null to request defaults.
  std::vector<llvm::Value*> args1 = {llvm::ConstantInt::get(i32t, FindIndexByTag(block, "stencil_m")->range),
                                     llvm::ConstantInt::get(i32t, FindIndexByTag(block, "stencil_n")->range),
                                     llvm::ConstantInt::get(i32t, FindIndexByTag(block, "stencil_k")->range),
                                     lda,
                                     ldb,
                                     ldc,
                                     alpha,
                                     beta,
                                     nptr,
                                     nptr};
  llvm::Value* func = builder_.CreateCall(dispatch, args1);

  std::vector<llvm::Type*> param_types{
      func->getType(),                                     // ptr of function to call
      buffers_[xsmmCallData.in1->into()].base->getType(),  // a
      buffers_[xsmmCallData.in0->into()].base->getType(),  // b
      buffers_[xsmmCallData.out0->into()].base->getType()  // c
  };
  // NOTE(review): rftype and xsmmCallHelperType are constructed identically;
  // one of the two could likely be reused.
  llvm::FunctionType* rftype = llvm::FunctionType::get(builder_.getVoidTy(), param_types, false);
  llvm::FunctionType* xsmmCallHelperType = llvm::FunctionType::get(builder_.getVoidTy(), param_types, false);
  auto xmmCallFunc = module_->getOrInsertFunction("XSMMRTCaller", xsmmCallHelperType).getCallee();

// Compute the argument pointer for each operand, applying the constant
// element offset from the flattened access when it is nonzero.
#define CREATE_OFFSET_STMTS(name)                                                          \
  llvm::Value* arg_##name;                                                                 \
  if (xsmmCallData.offset_##name == 0) {                                                   \
    arg_##name = buffers_[xsmmCallData.name->into()].base;                                 \
  } else {                                                                                 \
    llvm::Value* idx_list[1] = {llvm::ConstantInt::get(i32t, xsmmCallData.offset_##name)}; \
    arg_##name = builder_.CreateGEP(buffers_[xsmmCallData.name->into()].base, idx_list);   \
  }
  CREATE_OFFSET_STMTS(in0);
  CREATE_OFFSET_STMTS(in1);
  CREATE_OFFSET_STMTS(out0);

  std::vector<llvm::Value*> args2 = {func, arg_in1, arg_in0, arg_out0};
  builder_.CreateCall(rftype, xmmCallFunc, args2);
  builder_.CreateRetVoid();
  return function;
}

// Gets the leading dimensions and the buffers for an XSMM call if available.
// @returns true if the XSMM call is applicable, otherwise false.
bool Compiler::GetXSMMCallData(XSMMCallData* xsmmCallData, const stripe::Block& block) {
  // Locate the m/n/k stencil indexes by tag; all three must be present
  // (or n absent in the special-N case below) for XSMM to apply.
  std::string m_name;
  std::string n_name;
  std::string k_name;

  uint32_t found = 0;
  for (const auto& idx : block.idxs) {
    if (idx.has_tag("stencil_m")) {
      found++;
      m_name = idx.name;
    }
    if (idx.has_tag("stencil_n")) {
      found++;
      n_name = idx.name;
    }
    if (idx.has_tag("stencil_k")) {
      found++;
      k_name = idx.name;
    }
  }

  // Special case n-dimension is 1 and an earlier optimization removes the stencil for n (being 1).
  bool useSpecialN = false;
  if (found == 2 && n_name.empty()) {
    useSpecialN = true;
    found++;
  }

  if (found != 3) {
    return false;
  }

  // Pull the leading dimensions, base refinements, and constant element
  // offsets out of each operand's flattened access.
  for (const auto& ref : block.refs) {
    if (ref.has_tag("A")) {
      auto flat = ref.FlatAccess();
      if (useSpecialN) {
        // NOTE(review): this sets lda_b_value to 0, but the final validation
        // below rejects lda_b_value == 0, so the special-N path appears to
        // always return false — confirm the intended behavior.
        xsmmCallData->lda_b_value = 0;
      } else {
        xsmmCallData->lda_b_value = flat[n_name];
      }
      xsmmCallData->in0 = &ref;
      xsmmCallData->offset_in0 = flat.constant();
    } else if (ref.has_tag("B")) {
      auto flat = ref.FlatAccess();
      xsmmCallData->lda_a_value = flat[k_name];
      xsmmCallData->in1 = &ref;
      xsmmCallData->offset_in1 = flat.constant();
    } else if (ref.has_tag("C")) {
      auto flat = ref.FlatAccess();
      xsmmCallData->lda_c_value = flat[n_name];
      xsmmCallData->out0 = &ref;
      xsmmCallData->offset_out0 = flat.constant();
    }
  }

  // All three operands must be present with nonzero leading dimensions.
  if (xsmmCallData->in0 == nullptr || xsmmCallData->in1 == nullptr || xsmmCallData->out0 == nullptr ||
      xsmmCallData->lda_a_value == 0 || xsmmCallData->lda_b_value == 0 || xsmmCallData->lda_c_value == 0) {
    return false;
  }

  return true;
}

// Make sure all the refinements of this block are of the same type.
// If they are not, XSMM functions can't be called and we should fall
// back to the slower GEMM calculation process.
// Determine which libxsmm dispatch function (if any) can implement this
// block, based on the element types of its refinements. Returns
// XSMMDispatch::NONE when the block is not eligible for XSMM.
const XSMMDispatch Compiler::GetXSMMDispatch(const stripe::Block& block) {
  DataType dataType = DataType::INVALID;
  XSMMDispatch xsmmDispatch = XSMMDispatch::NONE;
  // First check to see if all parameters and output are float32 or float64.
  bool firstIteration = true;
  const auto allRefs = block.refs;
  for (auto it = allRefs.cbegin(); it != allRefs.cend(); ++it) {
    if (firstIteration) {
      dataType = it->interior_shape.type;
      firstIteration = false;
      if (dataType == DataType::FLOAT32) {
        xsmmDispatch = XSMMDispatch::SMM;
      } else if (dataType == DataType::FLOAT64) {
        // BUGFIX: this branch previously re-tested FLOAT32, which made the
        // double-precision (DMM) dispatch unreachable.
        xsmmDispatch = XSMMDispatch::DMM;
      } else {
        break;
      }
    } else {
      if (dataType != it->interior_shape.type) {
        // Refinements with different DataTypes detected.
        // Return INVALID, so the XSMM logic detects XSMM
        // should not be used.
        dataType = DataType::INVALID;
        xsmmDispatch = XSMMDispatch::NONE;
        break;
      }
    }
  }

  // Now check for the WIMM, BSMM, and BMMM XSMM dispatch functions.
  if (xsmmDispatch == XSMMDispatch::NONE) {
    DataType in0 = DataType::INVALID;
    DataType in1 = DataType::INVALID;
    DataType out = DataType::INVALID;
    if (block.refs.size() == 3) {  // Only three refinements.
      for (const auto& ref : block.refs) {
        if (ref.has_tag("A")) {
          in0 = ref.interior_shape.type;
        } else if (ref.has_tag("B")) {
          in1 = ref.interior_shape.type;
        } else if (ref.has_tag("C")) {
          out = ref.interior_shape.type;
        }
      }

      // WIMM: int8 x uint8 -> int32.
      if (in0 == DataType::INT8 && in1 == DataType::UINT8 && out == DataType::INT32) {
        xsmmDispatch = XSMMDispatch::WIMM;
      }

      // TODO: Handle bfloat16 when support added to Stripe.
    }
  }

  return xsmmDispatch;
}

// Generate a function implementing the body of this block, to be invoked via
// ParallelFor over a sub-range of the joint (flattened) index space. Buffers
// (refinements) and index init values arrive packed in arrays, matching the
// cpu_thread_block signature defined in executable.cc.
llvm::Function* Compiler::CompileThreadedBlock(const stripe::Block& block) {
  for (const auto& ref : block.refs) {
    buffers_[ref.into()] = Buffer{&ref};
  }
  for (const auto& idx : block.idxs) {
    indexes_[idx.name] = Index{&idx};
  }

  // create the LLVM function which will implement the Stripe block
  auto linkage = llvm::Function::ExternalLinkage;
  auto name = block.name;
  auto func_type = BlockType(block);
  auto function = llvm::Function::Create(func_type, linkage, name, module_);
  // create a basic block; configure the builder to start there
  auto bb = llvm::BasicBlock::Create(context_, "entry", function);
  builder_.SetInsertPoint(bb);

  // This block will be invoked via ParallelFor, and its parameter signature
  // must match the cpu_thread_block type defined in executable.cc.
  assert(4 == function->arg_size());
  // First parameter is a pointer to an array of refinement pointers.
  // The type will always be int8_t**, so we must cast the pointer type.
  llvm::Value* refsArray = function->getArg(0);
  for (unsigned i = 0; i < block.refs.size(); ++i) {
    llvm::Value* refElement = builder_.CreateConstGEP1_32(refsArray, i);
    llvm::Value* refPtr = builder_.CreateLoad(refElement);
    auto it = block.refs.begin();
    std::advance(it, i);
    llvm::Type* buftype = CType(it->interior_shape.type)->getPointerTo();
    std::string refName = it->into();
    buffers_[refName].base = builder_.CreateBitCast(refPtr, buftype);
  }
  // Second parameter points to an array of index init values.
  llvm::Value* initsArray = function->getArg(1);
  for (unsigned i = 0; i < block.idxs.size(); ++i) {
    llvm::Value* idxElement = builder_.CreateConstGEP1_32(initsArray, i);
    llvm::Value* idxInit = builder_.CreateLoad(idxElement);
    std::string idxName = block.idxs[i].name;
    indexes_[idxName].init = idxInit;
    llvm::Value* idx_ptr = builder_.CreateAlloca(IndexType());
    indexes_[idxName].variable = idx_ptr;
  }
  // Third parameter is the composite index range begin.
  // Fourth parameter is the composite index range end.
  // We will use these to replace the init & limit values for index 0.

  // Construct the joint index loop
  Loop joint_loop;
  llvm::Value* joint_idx = builder_.CreateAlloca(IndexType());
  CreateLoop(&joint_loop, "joint_idx");
  EnterLoop(&joint_loop, joint_idx, function->getArg(2), function->getArg(3));

  // Decompose the joint index into per-dimension index values (mixed-radix
  // decode by each index's range), adding each index's init value.
  llvm::Value* cur = builder_.CreateLoad(joint_idx);
  for (auto& idx : block.idxs) {
    auto low_part = builder_.CreateURem(cur, IndexConst(idx.range));
    auto with_init = builder_.CreateAdd(low_part, indexes_[idx.name].init);
    cur = builder_.CreateUDiv(cur, IndexConst(idx.range));
    builder_.CreateStore(with_init, indexes_[idx.name].variable);
  }

  // check the constraints against the current index values and decide whether
  // to execute the block body for this iteration
  llvm::Value* go = builder_.getTrue();
  for (auto& constraint : block.constraints) {
    llvm::Value* gateval = Eval(constraint);
    llvm::Value* check = builder_.CreateICmpSGE(gateval, IndexConst(0));
    go = builder_.CreateAnd(check, go);
  }
  auto block_body = llvm::BasicBlock::Create(context_, "block", function);
  auto block_done = llvm::BasicBlock::Create(context_, "next", function);
  builder_.CreateCondBr(go, block_body, block_done);
  builder_.SetInsertPoint(block_body);

  // process each statement in the block body, generating code to modify the
  // parameter buffer contents
  // (BUGFIX: removed an unused shared_ptr deep copy of `block` here.)
  for (const auto& stmt : block.stmts) {
    stmt->Accept(this);
  }

  // rejoin instruction flow after the constraint check
  builder_.CreateBr(block_done);
  builder_.SetInsertPoint(block_done);

  // increment the joint index and jump back to the loop test
  LeaveLoop(&joint_loop, joint_idx);

  builder_.CreateRetVoid();
  return function;
}

// Compile one Stripe block into an LLVM function. Dispatches to the XSMM or
// threaded specializations when applicable; otherwise emits a nest of loops
// (one per index) around the block body, gated by the block's constraints.
llvm::Function* Compiler::CompileBlock(const stripe::Block& block) {
  CompileFor compileFor = getCompileFor(block);
  if (compileFor == XSMM_BLOCK) {
    assert(!block.has_tag("cpu_thread"));
    const XSMMDispatch xsmmDispatch = GetXSMMDispatch(block);
    XSMMCallData xsmmCallData;
    auto data = GetXSMMCallData(&xsmmCallData, block);
    // Fall through to the generic path if XSMM turns out not to apply.
    if (xsmmDispatch != XSMMDispatch::NONE && data) {
      return CompileXSMMBlock(block, xsmmDispatch, xsmmCallData);
    }
  } else if (compileFor == THREADED_BLOCK) {
    return CompileThreadedBlock(block);
  }

  // Generate a function implementing the body of this block.
  // Buffers (refinements) will be passed in as function parameters, as will
  // the initial value for each index.

  for (const auto& ref : block.refs) {
    buffers_[ref.into()] = Buffer{&ref};
  }
  for (const auto& idx : block.idxs) {
    indexes_[idx.name] = Index{&idx};
  }

  // create the LLVM function which will implement the Stripe block
  auto linkage = llvm::Function::ExternalLinkage;
  auto name = block.name;
  auto func_type = BlockType(block);
  auto function = llvm::Function::Create(func_type, linkage, name, module_);
  // create a basic block; configure the builder to start there
  auto bb = llvm::BasicBlock::Create(context_, "entry", function);
  builder_.SetInsertPoint(bb);

  // This block will be invoked through a direct function call.
  // First, a parameter for each refinement, containing the base address.
  // Then, a parameter for each index, containing the initial value.
  for (auto ai = function->arg_begin(); ai != function->arg_end(); ++ai) {
    unsigned idx = ai->getArgNo();
    if (idx < block.refs.size()) {
      auto it = block.refs.begin();
      std::advance(it, idx);
      std::string param_name = it->into();
      ai->setName(param_name);
      assert(nullptr == buffers_[param_name].base);
      buffers_[param_name].base = &(*ai);
    } else {
      idx -= block.refs.size();
      std::string param_name = block.idxs[idx].name;
      ai->setName(param_name);
      assert(nullptr == indexes_[param_name].init);
      indexes_[param_name].init = &(*ai);
    }
  }

  // allocate storage for each loop index
  for (auto& idx : block.idxs) {
    llvm::Value* variable = builder_.CreateAlloca(IndexType());
    variable->setName(idx.name);
    assert(nullptr == indexes_[idx.name].variable);
    indexes_[idx.name].variable = variable;
  }

  // compute the limit value for each loop: init + range
  std::vector<llvm::Value*> limits(block.idxs.size());
  for (size_t i = 0; i < block.idxs.size(); ++i) {
    llvm::Value* init = indexes_[block.idxs[i].name].init;
    llvm::Value* range = IndexConst(block.idxs[i].range);
    limits[i] = builder_.CreateAdd(init, range);
  }

  // generate the basic blocks for each nested loop's evaluation stages
  // initialize each loop index and generate the termination check
  std::vector<Loop> loops(block.idxs.size());
  for (size_t i = 0; i < block.idxs.size(); ++i) {
    std::string name = block.idxs[i].name;
    llvm::Value* variable = indexes_[name].variable;
    llvm::Value* init = indexes_[name].init;
    CreateLoop(&loops[i], name);
    EnterLoop(&loops[i], variable, init, limits[i]);
  }

  // check the constraints against the current index values and decide whether
  // to execute the block body for this iteration
  llvm::Value* go = builder_.getTrue();
  for (auto& constraint : block.constraints) {
    llvm::Value* gateval = Eval(constraint);
    llvm::Value* check = builder_.CreateICmpSGE(gateval, IndexConst(0));
    go = builder_.CreateAnd(check, go);
  }
  auto block_body = llvm::BasicBlock::Create(context_, "block", function);
  auto block_done = llvm::BasicBlock::Create(context_, "next", function);
  builder_.CreateCondBr(go, block_body, block_done);
  builder_.SetInsertPoint(block_body);

  ProfileLoopEnter(block);

  // process each statement in the block body, generating code to modify the
  // parameter buffer contents
  // (BUGFIX: removed an unused shared_ptr deep copy of `block` here.)
  for (const auto& stmt : block.stmts) {
    stmt->Accept(this);
  }

  ProfileLoopLeave(block);

  // rejoin instruction flow after the constraint check
  builder_.CreateBr(block_done);
  builder_.SetInsertPoint(block_done);

  // increment each index, from innermost to outermost, then jump back to test
  for (size_t i = block.idxs.size(); i-- > 0;) {
    llvm::Value* variable = indexes_[block.idxs[i].name].variable;
    LeaveLoop(&loops[i], variable);
  }

  builder_.CreateRetVoid();
  return function;
}

// Copy one element out of a buffer into a named scalar. load.from names the
// source buffer; load.into names the destination scalar.
void Compiler::Visit(const stripe::Load& load) {
  Buffer source = buffers_[load.from];
  // Compute the address of the current element, then read it; the loaded
  // value becomes the new definition of the destination scalar, carrying the
  // source buffer's element type.
  llvm::Value* addr = ElementPtr(source);
  llvm::Value* loaded = builder_.CreateLoad(addr, load.into);
  scalars_[load.into] = Scalar{loaded, source.refinement->interior_shape.type};
}

// Write a scalar into a buffer element, combining it with the element's
// current contents according to the destination refinement's aggregation op
// (assign, add, mul, min, max). store.from names the source scalar;
// store.into names the destination buffer.
void Compiler::Visit(const stripe::Store& store) {
  Buffer dest = buffers_[store.into];
  // Convert the source scalar to the destination element type up front.
  Scalar src = Cast(scalars_[store.from], dest.refinement->interior_shape.type);
  llvm::Value* result = src.value;
  llvm::Value* slot = ElementPtr(dest);
  const std::string& op = dest.refinement->agg_op;
  if (op == "add") {
    llvm::Value* current = builder_.CreateLoad(slot);
    if (is_float(src.type)) {
      result = builder_.CreateFAdd(result, current);
    } else if (is_int(src.type) || is_uint(src.type)) {
      result = builder_.CreateAdd(result, current);
    } else {
      throw Error("Invalid addition type: " + to_string(src.type));
    }
  } else if (op == "mul") {
    llvm::Value* current = builder_.CreateLoad(slot);
    if (is_float(src.type)) {
      result = builder_.CreateFMul(result, current);
    } else if (is_int(src.type) || is_uint(src.type)) {
      result = builder_.CreateMul(result, current);
    } else {
      throw Error("Invalid multiplication type: " + to_string(src.type));
    }
  } else if (op == "max") {
    // Keep whichever of (current, incoming) is greater.
    llvm::Value* current = builder_.CreateLoad(slot);
    llvm::Value* keep_current = nullptr;
    if (is_float(src.type)) {
      keep_current = builder_.CreateFCmpUGT(current, result);
    } else if (is_int(src.type)) {
      keep_current = builder_.CreateICmpSGT(current, result);
    } else if (is_uint(src.type)) {
      keep_current = builder_.CreateICmpUGT(current, result);
    }
    result = builder_.CreateSelect(keep_current, current, result);
  } else if (op == "min") {
    // Keep whichever of (current, incoming) is smaller.
    llvm::Value* current = builder_.CreateLoad(slot);
    llvm::Value* keep_current = nullptr;
    if (is_float(src.type)) {
      keep_current = builder_.CreateFCmpULT(current, result);
    } else if (is_int(src.type)) {
      keep_current = builder_.CreateICmpSLT(current, result);
    } else if (is_uint(src.type)) {
      keep_current = builder_.CreateICmpULT(current, result);
    }
    result = builder_.CreateSelect(keep_current, current, result);
  } else if (op == "assign") {
    // Plain overwrite: nothing to combine.
  } else if (!op.empty()) {
    throw Error("Unimplemented agg_op: " + to_string(op));
  }
  builder_.CreateStore(result, slot);
}

// Evaluate an affine expression over the current index values and bind the
// result to the destination scalar, which always carries INT64 type.
void Compiler::Visit(const stripe::LoadIndex& load_index) {
  llvm::Value* result = Eval(load_index.from);
  result->setName(load_index.into);
  scalars_[load_index.into] = Scalar{result, DataType::INT64};
}

// Materialize a literal as an LLVM constant and bind it to a scalar name:
// integer literals become INT64, float literals become FLOAT64.
void Compiler::Visit(const stripe::Constant& constant) {
  switch (constant.type) {
    case stripe::ConstType::Integer: {
      auto* val = llvm::ConstantInt::get(builder_.getInt64Ty(), constant.iconst);
      val->setName(constant.name);
      scalars_[constant.name] = Scalar{val, DataType::INT64};
    } break;
    case stripe::ConstType::Float: {
      auto* val = llvm::ConstantFP::get(builder_.getDoubleTy(), constant.fconst);
      val->setName(constant.name);
      scalars_[constant.name] = Scalar{val, DataType::FLOAT64};
    } break;
  }
}

void Compiler::Visit(const stripe::Special& special) {
  // Dispatch a "special" operation to its handler method.
  // The list of specials defined in the spec differs from the list defined in
  // tile/lang/gen_special.cc. The spec lists "zero", "copy", and "reshape",
  // while gen_special.cc uses "gather", "scatter", "shape", and "prng_step".
  static std::map<std::string, std::function<void(Compiler*, const stripe::Special&)>> handlers{
      {"zero", &Compiler::Zero},                //
      {"copy", &Compiler::Copy},                //
      {"reshape", &Compiler::Reshape},          //
      {"prng_step", &Compiler::PrngStep},       //
      {"shape", &Compiler::Shape},              //
      {"agg_init_add", &Compiler::AggInitAdd},  //
      {"agg_init_mul", &Compiler::AggInitMul},  //
      {"agg_init_min", &Compiler::AggInitMin},  //
      {"agg_init_max", &Compiler::AggInitMax},  //
      {"scatter", &Compiler::Scatter},          //
      {"gather", &Compiler::Gather},            //
  };
  auto handler = handlers.find(special.name);
  if (handler == handlers.end()) {
    throw Error("Unknown special \"" + special.name + "\"");
  }
  handler->second(this, special);
}

void Compiler::Visit(const stripe::Intrinsic& intrinsic) {
  // Dispatch an intrinsic to a handler.
  // If the context has provided an external handler for this intrinsic name,
  // we'll use it - that allows the context to override builtin definitions.
  // If there is no external handler, look up a builtin handler definition.
  // Note that stripe::Intrinsic defines a bunch of strings which are not
  // actually intrinsics; they are values for stripe::Refinement::agg_op. Not
  // sure why they are located in the wrong structure. There are no constants
  // defined for actual intrinsic names; these have been derived experimentally.
  static std::map<std::string, std::function<void(Compiler*, const stripe::Intrinsic&)>> builtins{
      {"add", &Compiler::Add},
      {"assign", &Compiler::Assign},
      {"cond", &Compiler::Conditional},
      {"div", &Compiler::Divide},
      {"eq", &Compiler::Equal},
      {"gt", &Compiler::GreaterThan},
      {"gte", &Compiler::GreaterThanOrEqualTo},
      {"ident", &Compiler::Assign},
      {"lt", &Compiler::LessThan},
      {"lte", &Compiler::LessThanOrEqualTo},
      {"mod", &Compiler::Mod},
      {"mul", &Compiler::Multiply},
      {"neg", &Compiler::Negate},
      {"neq", &Compiler::Unequal},
      {"sub", &Compiler::Subtract},
      // Extra operations defined in tile/lang/ops.cc, which are apparently
      // passed along directly into Stripe
      {"bit_and", &Compiler::And},
      {"bit_or", &Compiler::Or},
      {"bit_not", &Compiler::Not},
      {"bit_left", &Compiler::BitLeft},
      {"bit_right", &Compiler::BitRight},
      {"bit_xor", &Compiler::Xor},
      {"cmp_eq", &Compiler::Equal},
      {"cmp_ne", &Compiler::Unequal},
      {"cmp_lt", &Compiler::LessThan},
      {"cmp_gt", &Compiler::GreaterThan},
      {"cmp_le", &Compiler::LessThanOrEqualTo},
      {"cmp_ge", &Compiler::GreaterThanOrEqualTo},
      // Other undocumented intrinsics, which are apparently necessary in order
      // to successfully run the backend_test:
      {"as_bool", &Compiler::AsBool},
      {"as_float", &Compiler::AsFloat},
      {"as_int", &Compiler::AsInt},
      {"as_uint", &Compiler::AsUInt},
      {"ceil", &Compiler::Ceil},
      {"cos", &Compiler::Cos},
      {"exp", &Compiler::Exp},
      {"floor", &Compiler::Floor},
      {"log", &Compiler::Log},
      {"pow", &Compiler::Pow},
      {"round", &Compiler::Round},
      {"sqrt", &Compiler::Sqrt},
      {"tanh", &Compiler::Tanh},
      // Numeric operations from stdlib mentioned in tile/lang/builtins.cc
      {"abs", &Compiler::Abs},
      {"acos", &Compiler::Acos},
      {"acosh", &Compiler::Acosh},
      {"asin", &Compiler::Asin},
      {"asinh", &Compiler::Asinh},
      {"atan", &Compiler::Atan},
      {"atanh", &Compiler::Atanh},
      {"cosh", &Compiler::Cosh},
      {"sin", &Compiler::Sin},
      {"sinh", &Compiler::Sinh},
      {"tan", &Compiler::Tan},
  };
  // Externals take precedence so the context can override any builtin.
  auto external = config_.externals.find(intrinsic.name);
  if (external != config_.externals.end()) {
    Intrinsic(intrinsic, external->second);
    return;
  }
  auto builtin = builtins.find(intrinsic.name);
  if (builtin == builtins.end()) {
    throw Error("Unknown intrinsic \"" + intrinsic.name + "\"");
  }
  builtin->second(this, intrinsic);
}

void Compiler::Visit(const stripe::Block& block) {
  // Compile a nested block as a function in the same module
  Compiler nested(&context_, module_, config_);
  auto function = nested.CompileBlock(block);
  // Propagate any external funcptrs the nested compile registered, so they
  // are available when this module is linked/executed.
  for (auto& fptr_iter : nested.external_funcptrs_) {
    external_funcptrs_.emplace(fptr_iter);
  }

  ProfileBlockEnter(block);
  // Generate a list of args.
  // The argument list begins with a pointer to each refinement. We will either
  // pass along the address of a refinement from the current block, or allocate
  // a new buffer for the nested block's use.
  std::vector<llvm::Value*> refs;
  std::vector<llvm::Value*> allocs;
  for (auto& ref : block.refs) {
    llvm::Value* buffer = nullptr;
    // When a refinement is neither in nor out, and it has no "from"
    // name, it represents a local allocation.
    if (ref.dir == stripe::RefDir::None && ref.from.empty()) {
      if (ref.has_tag("placed")) {
        // "placed" local buffers live at a fixed offset within the shared
        // arena global rather than in freshly-malloc'd storage.
        auto arena = module_->getNamedGlobal(arena_name_);
        auto baseArenaAddress = builder_.CreateLoad(arena);
        std::vector<llvm::Value*> idxList{IndexConst(ref.offset)};
        buffer = builder_.CreateGEP(baseArenaAddress, idxList);
      } else {
        // Allocate new storage for the buffer.
        buffer = Malloc(ref.interior_shape.byte_size());
        allocs.push_back(buffer);
      }
      llvm::Type* buftype = CType(ref.interior_shape.type)->getPointerTo();
      buffer = builder_.CreateBitCast(buffer, buftype, ref.into());
    } else {
      // Pass in the current element address from the source buffer.
      // If a "from" name is specified, use that buffer; if not, that means
      // that both blocks use the same name, so use "into".
      std::string name = ref.from.empty() ? ref.into() : ref.from;
      buffer = ElementPtr(buffers_[name]);
    }
    refs.push_back(buffer);
  }
  // Following the list of refinement args, we will provide a list of initial
  // values for each of the block's indexes, which are specified as an affine
  // in terms of the current block's indexes.
  std::vector<llvm::Value*> idxs;
  for (auto& idx : block.idxs) {
    idxs.push_back(Eval(idx.affine));
  }

  // Assemble the argument list and invoke the function.
  if (getCompileFor(block) == THREADED_BLOCK) {
    assert(!block.has_tag("xsmm"));
    // Threaded blocks use a uniform calling convention:
    // (i8** bufs, index* idx_inits, index range_start, index range_end).
    // Combine the bufs into an array; pass it as the first parameter.
    auto int8PtrType = builder_.getInt8Ty()->getPointerTo();
    auto int8PtrArrayType = llvm::ArrayType::get(int8PtrType, refs.size());
    llvm::Value* bufsArg = builder_.CreateAlloca(int8PtrArrayType);
    bufsArg = builder_.CreateBitCast(bufsArg, int8PtrType->getPointerTo());
    for (size_t i = 0; i < refs.size(); ++i) {
      llvm::Value* castRef = builder_.CreateBitCast(refs[i], int8PtrType);
      llvm::Value* elementPtr = builder_.CreateConstGEP1_32(bufsArg, i);
      builder_.CreateStore(castRef, elementPtr);
    }
    // Combine the idx inits into an array; pass it as the second parameter.
    auto indexArrayType = llvm::ArrayType::get(IndexType(), idxs.size());
    llvm::Value* initsArg = builder_.CreateAlloca(indexArrayType);
    initsArg = builder_.CreateBitCast(initsArg, IndexType()->getPointerTo());
    for (size_t i = 0; i < idxs.size(); ++i) {
      llvm::Value* elementPtr = builder_.CreateConstGEP1_32(initsArg, i);
      builder_.CreateStore(idxs[i], elementPtr);
    }
    if (!block.idxs.empty()) {
      // Total work is the product of all index ranges; ParallelFor divides
      // that range among worker threads.
      size_t total_range = 1;
      for (auto& idx : block.idxs) {
        total_range *= idx.range;
      }
      ParallelFor(bufsArg, initsArg, total_range, function);
    } else {
      // There is no point in using ParallelFor to invoke a block which has no
      // indexes, since there is no way to divide the work among threads.
      auto zero = IndexConst(0);
      builder_.CreateCall(function, {bufsArg, initsArg, zero, zero});
    }
  } else {
    // Argument list consists of the refinements, followed by the index inits.
    std::vector<llvm::Value*> args;
    args.insert(args.end(), refs.begin(), refs.end());
    args.insert(args.end(), idxs.begin(), idxs.end());
    // Invoke the function. It does not return a value.
    builder_.CreateCall(function, args, "");
  }

  // Free the temporary buffers we allocated as parameter values.
  for (auto ptr : allocs) {
    Free(ptr);
  }

  ProfileBlockLeave(block);
}

void Compiler::Intrinsic(const stripe::Intrinsic& intrinsic, External handler) {
  // Process an intrinsic statement using an external handler function.
  // Load all the input scalars. Create a vector containing their types.
  // We will provide this as input to the handler, so the handler may perform
  // overloading if it so desires. The handler must replace the input types
  // with its own list of desired inputs. We will use this for argument count
  // verification, then cast each input scalar accordingly.
  std::vector<Scalar> inputs;
  std::vector<DataType> input_types;
  for (auto& input_name : intrinsic.inputs) {
    Scalar input = scalars_[input_name];
    inputs.push_back(input);
    input_types.push_back(input.type);
  }
  DataType output_type = intrinsic.type;
  // Call the handler. It will process the input and output types, then return
  // an entrypoint address. If it does not like the types, it may throw, or
  // return nullptr_t, in which case we will throw.
  auto funcptr = handler(&input_types, &output_type);
  if (!funcptr) {
    throw Error("External intrinsic rejected for " + intrinsic.name);
  }
  // Verify that we have the expected number of inputs. Cast each input to the
  // type specified by the handler.
  if (inputs.size() != input_types.size()) {
    throw Error("External intrinsic " + intrinsic.name + " expects " + std::to_string(input_types.size()) +
                " input(s), but the invocation " + "provided " + std::to_string(inputs.size()));
  }
  // Verify the output count up front, before we index intrinsic.outputs[0]
  // below; previously the check happened after the access, so a mismatched
  // invocation read out of bounds instead of throwing.
  size_t expected_outputs = (output_type != DataType::INVALID) ? 1 : 0;
  if (expected_outputs != intrinsic.outputs.size()) {
    throw Error("External intrinsic " + intrinsic.name + " expects " + std::to_string(expected_outputs) +
                " output(s), but the invocation " + "provided " + std::to_string(intrinsic.outputs.size()));
  }
  std::vector<llvm::Type*> argtypes(inputs.size());
  std::vector<llvm::Value*> argvals(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    inputs[i] = Cast(inputs[i], input_types[i]);
    argvals[i] = inputs[i].value;
    assert(argvals[i]);
    argtypes[i] = CType(inputs[i].type);
  }
  // Build a function type signature for this list of input and output types.
  llvm::Type* rtype = CType(output_type);
  auto functype = llvm::FunctionType::get(rtype, argtypes, false);
  // Declare the function in the module under a well-known name, record its
  // raw entrypoint so the executable can resolve it, then emit a call.
  std::string funcname = "external_" + intrinsic.name;
  external_funcptrs_[funcname] = funcptr;
  auto funcval = module_->getOrInsertFunction(funcname.c_str(), functype).getCallee();
  auto ret = builder_.CreateCall(funcval, argvals, "");
  // If we have an output, store the new scalar value.
  if (output_type != DataType::INVALID) {
    scalars_[intrinsic.outputs[0]] = Scalar{ret, output_type};
  }
}

void Compiler::Add(const stripe::Intrinsic& add) {
  // Two inputs, both cast to the operation type; the sum becomes the single
  // output, which also carries the operation type.
  assert(add.inputs.size() == 2);
  Scalar a = Cast(scalars_[add.inputs[0]], add.type);
  Scalar b = Cast(scalars_[add.inputs[1]], add.type);
  llvm::Value* sum = nullptr;
  if (is_int(add.type) || is_uint(add.type)) {
    sum = builder_.CreateAdd(a.value, b.value);
  } else if (is_float(add.type)) {
    sum = builder_.CreateFAdd(a.value, b.value);
  } else {
    throw Error("Invalid addition type: " + to_string(add.type));
  }
  OutputType(sum, add);
}

void Compiler::Subtract(const stripe::Intrinsic& sub) {
  // Two inputs, both cast to the operation type; the difference becomes the
  // single output, which also carries the operation type.
  assert(sub.inputs.size() == 2);
  Scalar a = Cast(scalars_[sub.inputs[0]], sub.type);
  Scalar b = Cast(scalars_[sub.inputs[1]], sub.type);
  llvm::Value* diff = nullptr;
  if (is_int(sub.type) || is_uint(sub.type)) {
    diff = builder_.CreateSub(a.value, b.value);
  } else if (is_float(sub.type)) {
    diff = builder_.CreateFSub(a.value, b.value);
  } else {
    throw Error("Invalid subtraction type: " + to_string(sub.type));
  }
  OutputType(diff, sub);
}

void Compiler::Negate(const stripe::Intrinsic& neg) {
  // One input, cast to the operation type; the negated value becomes the
  // single output, which also carries the operation type.
  assert(neg.inputs.size() == 1);
  Scalar operand = Cast(scalars_[neg.inputs[0]], neg.type);
  llvm::Value* negated = nullptr;
  if (is_int(neg.type) || is_uint(neg.type)) {
    negated = builder_.CreateNeg(operand.value);
  } else if (is_float(neg.type)) {
    negated = builder_.CreateFNeg(operand.value);
  } else {
    throw Error("Invalid negation type: " + to_string(neg.type));
  }
  OutputType(negated, neg);
}

void Compiler::Multiply(const stripe::Intrinsic& mul) {
  // Two inputs, both cast to the operation type; the product becomes the
  // single output, which also carries the operation type.
  assert(mul.inputs.size() == 2);
  Scalar a = Cast(scalars_[mul.inputs[0]], mul.type);
  Scalar b = Cast(scalars_[mul.inputs[1]], mul.type);
  llvm::Value* product = nullptr;
  if (is_int(mul.type) || is_uint(mul.type)) {
    product = builder_.CreateMul(a.value, b.value);
  } else if (is_float(mul.type)) {
    product = builder_.CreateFMul(a.value, b.value);
  } else {
    throw Error("Invalid multiplication type: " + to_string(mul.type));
  }
  OutputType(product, mul);
}

void Compiler::Divide(const stripe::Intrinsic& div) {
  // Two inputs, both cast to the operation type; the quotient becomes the
  // single output, which also carries the operation type. Signedness picks
  // between SDiv and UDiv for integers.
  assert(div.inputs.size() == 2);
  Scalar a = Cast(scalars_[div.inputs[0]], div.type);
  Scalar b = Cast(scalars_[div.inputs[1]], div.type);
  llvm::Value* quotient = nullptr;
  if (is_int(div.type)) {
    quotient = builder_.CreateSDiv(a.value, b.value);
  } else if (is_uint(div.type)) {
    quotient = builder_.CreateUDiv(a.value, b.value);
  } else if (is_float(div.type)) {
    quotient = builder_.CreateFDiv(a.value, b.value);
  } else {
    throw Error("Invalid division type: " + to_string(div.type));
  }
  OutputType(quotient, div);
}

void Compiler::Mod(const stripe::Intrinsic& mod) {
  // Two inputs, both cast to the operation type; the remainder becomes the
  // single output. Only integer types are supported; signedness picks
  // between SRem and URem.
  assert(mod.inputs.size() == 2);
  Scalar a = Cast(scalars_[mod.inputs[0]], mod.type);
  Scalar b = Cast(scalars_[mod.inputs[1]], mod.type);
  llvm::Value* rem = nullptr;
  if (is_uint(mod.type)) {
    rem = builder_.CreateURem(a.value, b.value);
  } else if (is_int(mod.type)) {
    rem = builder_.CreateSRem(a.value, b.value);
  } else {
    throw Error("Invalid modulo type: " + to_string(mod.type));
  }
  OutputType(rem, mod);
}

void Compiler::LessThan(const stripe::Intrinsic& lt) {
  // Two inputs, cast to the operation type; boolean comparison result goes
  // into the single output.
  assert(lt.inputs.size() == 2);
  Scalar a = Cast(scalars_[lt.inputs[0]], lt.type);
  Scalar b = Cast(scalars_[lt.inputs[1]], lt.type);
  llvm::Value* flag = nullptr;
  if (is_int(lt.type)) {
    flag = builder_.CreateICmpSLT(a.value, b.value);
  } else if (is_uint(lt.type)) {
    flag = builder_.CreateICmpULT(a.value, b.value);
  } else if (is_float(lt.type)) {
    flag = builder_.CreateFCmpOLT(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (LT): " + to_string(lt.type));
  }
  OutputBool(flag, lt);
}

void Compiler::LessThanOrEqualTo(const stripe::Intrinsic& lte) {
  // Two inputs, cast to the operation type; boolean comparison result goes
  // into the single output.
  assert(lte.inputs.size() == 2);
  Scalar a = Cast(scalars_[lte.inputs[0]], lte.type);
  Scalar b = Cast(scalars_[lte.inputs[1]], lte.type);
  llvm::Value* flag = nullptr;
  if (is_int(lte.type)) {
    flag = builder_.CreateICmpSLE(a.value, b.value);
  } else if (is_uint(lte.type)) {
    flag = builder_.CreateICmpULE(a.value, b.value);
  } else if (is_float(lte.type)) {
    flag = builder_.CreateFCmpOLE(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (LE): " + to_string(lte.type));
  }
  OutputBool(flag, lte);
}

void Compiler::GreaterThan(const stripe::Intrinsic& gt) {
  // Two inputs, cast to the operation type; boolean comparison result goes
  // into the single output.
  assert(gt.inputs.size() == 2);
  Scalar a = Cast(scalars_[gt.inputs[0]], gt.type);
  Scalar b = Cast(scalars_[gt.inputs[1]], gt.type);
  llvm::Value* flag = nullptr;
  if (is_int(gt.type)) {
    flag = builder_.CreateICmpSGT(a.value, b.value);
  } else if (is_uint(gt.type)) {
    flag = builder_.CreateICmpUGT(a.value, b.value);
  } else if (is_float(gt.type)) {
    flag = builder_.CreateFCmpOGT(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (GT): " + to_string(gt.type));
  }
  OutputBool(flag, gt);
}

void Compiler::GreaterThanOrEqualTo(const stripe::Intrinsic& gte) {
  // Two inputs, cast to the operation type; boolean comparison result goes
  // into the single output.
  assert(gte.inputs.size() == 2);
  Scalar a = Cast(scalars_[gte.inputs[0]], gte.type);
  Scalar b = Cast(scalars_[gte.inputs[1]], gte.type);
  llvm::Value* flag = nullptr;
  if (is_int(gte.type)) {
    flag = builder_.CreateICmpSGE(a.value, b.value);
  } else if (is_uint(gte.type)) {
    flag = builder_.CreateICmpUGE(a.value, b.value);
  } else if (is_float(gte.type)) {
    flag = builder_.CreateFCmpOGE(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (GE): " + to_string(gte.type));
  }
  OutputBool(flag, gte);
}

void Compiler::Equal(const stripe::Intrinsic& eq) {
  // Two inputs, cast to the operation type; boolean equality result goes
  // into the single output. Integer, unsigned, and boolean operands all use
  // the same integer-equality instruction.
  assert(eq.inputs.size() == 2);
  Scalar a = Cast(scalars_[eq.inputs[0]], eq.type);
  Scalar b = Cast(scalars_[eq.inputs[1]], eq.type);
  llvm::Value* flag = nullptr;
  if (is_int(eq.type) || is_uint(eq.type) || DataType::BOOLEAN == eq.type) {
    flag = builder_.CreateICmpEQ(a.value, b.value);
  } else if (is_float(eq.type)) {
    flag = builder_.CreateFCmpOEQ(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (EQ): " + to_string(eq.type));
  }
  OutputBool(flag, eq);
}

void Compiler::Unequal(const stripe::Intrinsic& neq) {
  // Two inputs, cast to the operation type; boolean inequality result goes
  // into the single output. Integer, unsigned, and boolean operands all use
  // the same integer-inequality instruction.
  assert(neq.inputs.size() == 2);
  Scalar a = Cast(scalars_[neq.inputs[0]], neq.type);
  Scalar b = Cast(scalars_[neq.inputs[1]], neq.type);
  llvm::Value* flag = nullptr;
  if (is_int(neq.type) || is_uint(neq.type) || DataType::BOOLEAN == neq.type) {
    flag = builder_.CreateICmpNE(a.value, b.value);
  } else if (is_float(neq.type)) {
    flag = builder_.CreateFCmpONE(a.value, b.value);
  } else {
    throw Error("Invalid comparison type (NE): " + to_string(neq.type));
  }
  OutputBool(flag, neq);
}

void Compiler::Conditional(const stripe::Intrinsic& cond) {
  // Three inputs: C, T, F. T and F carry the operation type; the selected
  // one becomes the single output.
  assert(cond.inputs.size() == 3);
  // Keras sometimes supplies fp32/fp64 for the condition; cast it to
  // boolean so the LLVM type system checks are satisfied.
  Scalar sel = Cast(scalars_[cond.inputs[0]], DataType::BOOLEAN);
  Scalar on_true = Cast(scalars_[cond.inputs[1]], cond.type);
  Scalar on_false = Cast(scalars_[cond.inputs[2]], cond.type);
  llvm::Value* chosen = builder_.CreateSelect(sel.value, on_true.value, on_false.value);
  OutputType(chosen, cond);
}

void Compiler::And(const stripe::Intrinsic& stmt) {
  // Bitwise AND of two non-float operands; result goes to the output.
  assert(stmt.inputs.size() == 2);
  Scalar a = CheckNotFloat(scalars_[stmt.inputs[0]]);
  Scalar b = CheckNotFloat(scalars_[stmt.inputs[1]]);
  OutputBool(builder_.CreateAnd(a.value, b.value), stmt);
}

void Compiler::Or(const stripe::Intrinsic& stmt) {
  // Bitwise OR of two non-float operands; result goes to the output.
  assert(stmt.inputs.size() == 2);
  Scalar a = CheckNotFloat(scalars_[stmt.inputs[0]]);
  Scalar b = CheckNotFloat(scalars_[stmt.inputs[1]]);
  OutputBool(builder_.CreateOr(a.value, b.value), stmt);
}

void Compiler::Not(const stripe::Intrinsic& stmt) {
  // Bitwise complement of one non-float operand; result goes to the output.
  assert(stmt.inputs.size() == 1);
  Scalar operand = CheckNotFloat(scalars_[stmt.inputs[0]]);
  OutputBool(builder_.CreateNot(operand.value), stmt);
}

void Compiler::Xor(const stripe::Intrinsic& stmt) {
  // Bitwise XOR of two non-float operands; result goes to the output.
  assert(stmt.inputs.size() == 2);
  Scalar a = CheckNotFloat(scalars_[stmt.inputs[0]]);
  Scalar b = CheckNotFloat(scalars_[stmt.inputs[1]]);
  OutputBool(builder_.CreateXor(a.value, b.value), stmt);
}

void Compiler::Assign(const stripe::Intrinsic& stmt) {
  // Forward the single input to the output, cast to the operation type.
  assert(stmt.inputs.size() == 1);
  Scalar converted = Cast(scalars_[stmt.inputs[0]], stmt.type);
  OutputType(converted.value, stmt);
}

void Compiler::BitLeft(const stripe::Intrinsic& stmt) {
  // Shift the first operand left by the second; both cast to the op type.
  assert(stmt.inputs.size() == 2);
  Scalar value = Cast(scalars_[stmt.inputs[0]], stmt.type);
  Scalar amount = Cast(scalars_[stmt.inputs[1]], stmt.type);
  OutputType(builder_.CreateShl(value.value, amount.value), stmt);
}

void Compiler::AsFloat(const stripe::Intrinsic& stmt) {
  // Cast inputs[0] to a float type whose width is given by inputs[1],
  // which must be a compile-time integer constant.
  assert(stmt.inputs.size() == 2);
  assert(stmt.outputs.size() == 1);
  Scalar width = scalars_[stmt.inputs[1]];
  int bits = llvm::cast<llvm::ConstantInt>(*width.value).getValue().getLimitedValue();
  DataType target = DataType::INVALID;
  if (bits == 32) {
    target = DataType::FLOAT32;
  } else if (bits == 64) {
    target = DataType::FLOAT64;
  } else {
    // TODO: Add bfloat16 when added to Stripe.
    throw std::runtime_error("Invalid bit count for as_float for CPU jit - " + std::to_string(bits));
  }
  Scalar result = Cast(scalars_[stmt.inputs[0]], target);
  scalars_[stmt.outputs[0]] = result;
  result.value->setName(stmt.outputs[0]);
}

void Compiler::AsInt(const stripe::Intrinsic& stmt) {
  // Cast inputs[0] to a signed integer type whose width is given by
  // inputs[1], which must be a compile-time integer constant.
  assert(stmt.inputs.size() == 2);
  assert(stmt.outputs.size() == 1);
  Scalar width = scalars_[stmt.inputs[1]];
  int bits = llvm::cast<llvm::ConstantInt>(*width.value).getValue().getLimitedValue();
  DataType target = DataType::INVALID;
  if (bits == 8) {
    target = DataType::INT8;
  } else if (bits == 16) {
    target = DataType::INT16;
  } else if (bits == 32) {
    target = DataType::INT32;
  } else if (bits == 64) {
    target = DataType::INT64;
  } else {
    throw std::runtime_error("Invalid bit count for as_int for CPU jit - " + std::to_string(bits));
  }
  Scalar result = Cast(scalars_[stmt.inputs[0]], target);
  scalars_[stmt.outputs[0]] = result;
  result.value->setName(stmt.outputs[0]);
}

void Compiler::AsUInt(const stripe::Intrinsic& stmt) {
  // Cast inputs[0] to an unsigned integer type whose width is given by
  // inputs[1], which must be a compile-time integer constant.
  assert(stmt.inputs.size() == 2);
  assert(stmt.outputs.size() == 1);
  Scalar width = scalars_[stmt.inputs[1]];
  int bits = llvm::cast<llvm::ConstantInt>(*width.value).getValue().getLimitedValue();
  DataType target = DataType::INVALID;
  if (bits == 8) {
    target = DataType::UINT8;
  } else if (bits == 16) {
    target = DataType::UINT16;
  } else if (bits == 32) {
    target = DataType::UINT32;
  } else if (bits == 64) {
    target = DataType::UINT64;
  } else {
    throw std::runtime_error("Invalid bit count for as_uint for CPU jit - " + std::to_string(bits));
  }
  Scalar result = Cast(scalars_[stmt.inputs[0]], target);
  scalars_[stmt.outputs[0]] = result;
  result.value->setName(stmt.outputs[0]);
}

void Compiler::AsBool(const stripe::Intrinsic& stmt) {
  // Cast the single input to boolean and bind it to the output name.
  assert(stmt.inputs.size() == 1);
  assert(stmt.outputs.size() == 1);
  Scalar result = Cast(scalars_[stmt.inputs[0]], DataType::BOOLEAN);
  scalars_[stmt.outputs[0]] = result;
  result.value->setName(stmt.outputs[0]);
}

void Compiler::BitRight(const stripe::Intrinsic& stmt) {
  // Shift the first operand right by the second; both cast to the op type.
  // Signed operands use an arithmetic shift, unsigned a logical shift.
  assert(stmt.inputs.size() == 2);
  Scalar value = Cast(scalars_[stmt.inputs[0]], stmt.type);
  Scalar amount = Cast(scalars_[stmt.inputs[1]], stmt.type);
  llvm::Value* shifted = nullptr;
  if (is_uint(stmt.type)) {
    shifted = builder_.CreateLShr(value.value, amount.value);
  } else if (is_int(stmt.type)) {
    shifted = builder_.CreateAShr(value.value, amount.value);
  } else {
    throw Error("Invalid bitshift type: " + to_string(stmt.type));
  }
  OutputType(shifted, stmt);
}

// Thin wrappers mapping Stripe intrinsics to libm calls; the first name is
// used for float operands, the second for double (argument count defaults
// to 1 unless given).
void Compiler::Sqrt(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "sqrtf", "sqrt"); }

void Compiler::Exp(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "expf", "exp"); }

void Compiler::Log(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "logf", "log"); }

void Compiler::Pow(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "powf", "pow", 2); }

void Compiler::Tanh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "tanhf", "tanh"); }

void Compiler::Cos(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "cosf", "cos"); }

void Compiler::Floor(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "floorf", "floor"); }

void Compiler::Ceil(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "ceilf", "ceil"); }

void Compiler::Round(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "roundf", "round"); }

void Compiler::Abs(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "absf", "abs"); }

// More libm-backed trigonometric/hyperbolic intrinsics (float name first,
// double name second), from the stdlib operations in tile/lang/builtins.cc.
void Compiler::Acos(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "acosf", "acos"); }

void Compiler::Asin(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "asinf", "asin"); }

void Compiler::Atan(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "atanf", "atan"); }

void Compiler::Acosh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "acoshf", "acosh"); }

void Compiler::Asinh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "asinhf", "asinh"); }

void Compiler::Atanh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "atanhf", "atanh"); }

void Compiler::Cosh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "coshf", "cosh"); }

void Compiler::Sin(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "sinf", "sin"); }

void Compiler::Sinh(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "sinhf", "sinh"); }

void Compiler::Tan(const stripe::Intrinsic& stmt) { CallIntrinsicFunc(stmt, "tanf", "tan"); }

void Compiler::Zero(const stripe::Special& zero) {
  // Fill the entire output refinement with zero bytes.
  assert(zero.inputs.empty());
  assert(zero.outputs.size() == 1);
  Buffer out = buffers_[zero.outputs[0]];
  auto byte_count = out.refinement->interior_shape.byte_size();
  builder_.CreateMemSet(out.base, builder_.getInt8(0), byte_count, llvm::MaybeAlign(0));
}

void Compiler::Copy(const stripe::Special& copy) {
  // "copy" appears in stripe.proto but its semantics were never specified,
  // so there is nothing meaningful to emit here.
  throw Error("Special operation COPY is not yet specified");
}

void Compiler::Reshape(const stripe::Special& reshape) {
  // A reshape leaves the bytes unchanged; only their interpretation differs,
  // so it compiles to a raw memcpy sized by the destination shape.
  assert(reshape.inputs.size() == 1);
  assert(reshape.outputs.size() == 1);
  Buffer src = buffers_[reshape.inputs[0]];
  Buffer dst = buffers_[reshape.outputs[0]];
  auto byte_count = IndexConst(dst.refinement->interior_shape.byte_size());
  builder_.CreateMemCpy(dst.base, llvm::MaybeAlign(0), src.base, llvm::MaybeAlign(0), byte_count);
}

void Compiler::PrngStep(const stripe::Special& prng_step) {
  // Input: a 3xN matrix of PRNG state. Outputs: the successor state (same
  // shape as the input) and a destination buffer to fill with values.
  assert(prng_step.inputs.size() == 1);
  assert(prng_step.outputs.size() == 2);
  Buffer state_in = buffers_[prng_step.inputs[0]];
  Buffer state_out = buffers_[prng_step.outputs[0]];
  assert(state_out.refinement->interior_shape == state_in.refinement->interior_shape);
  Buffer target = buffers_[prng_step.outputs[1]];
  // The runtime helper takes a float* destination and a count of 32-bit
  // elements, so size the fill in uint32 units.
  llvm::Value* target_ptr = builder_.CreateBitCast(target.base, builder_.getFloatTy()->getPointerTo());
  size_t target_bytes = target.refinement->interior_shape.byte_size();
  llvm::Value* element_count = IndexConst(target_bytes / sizeof(uint32_t));
  std::vector<llvm::Value*> args{state_in.base, state_out.base, target_ptr, element_count};
  builder_.CreateCall(PrngStepFunction(), args, "");
}

void Compiler::Shape(const stripe::Special& shape) {
  // Input is a tensor; output is a 1-D array with one element per input
  // dimension. Store each input dimension's size into the matching slot.
  assert(shape.inputs.size() == 1);
  assert(shape.outputs.size() == 1);
  Buffer src = buffers_[shape.inputs[0]];
  Buffer out = buffers_[shape.outputs[0]];
  const auto& src_shape = src.refinement->interior_shape;
  const auto& out_shape = out.refinement->interior_shape;
  size_t ndims = src_shape.dims.size();
  assert(out_shape.dims.size() == 1);
  assert(out_shape.elem_size() == ndims);
  uint64_t elem_bits = bit_width(out_shape.type);
  for (size_t dim = 0; dim < ndims; ++dim) {
    llvm::Value* size_val = builder_.getIntN(elem_bits, src_shape.dims[dim].size);
    llvm::Value* slot = builder_.CreateGEP(out.base, IndexConst(dim));
    builder_.CreateStore(size_val, slot);
  }
}

void Compiler::AggInitAdd(const stripe::Special& agg_init) {
  // One output: a tensor to initialize with the additive identity (zero).
  assert(0 == agg_init.inputs.size());
  assert(1 == agg_init.outputs.size());
  Buffer dest = buffers_[agg_init.outputs[0]];
  auto& dest_shape = dest.refinement->interior_shape;
  size_t bits = bit_width(dest_shape.type);
  llvm::Type* eltype = CType(dest_shape.type);
  llvm::Value* init_val = nullptr;
  if (is_float(dest_shape.type)) {
    init_val = llvm::ConstantFP::get(eltype, 0.0);
  } else if (is_int(dest_shape.type) || is_uint(dest_shape.type)) {
    auto apval = llvm::APInt::getNullValue(bits);
    init_val = llvm::ConstantInt::get(eltype, apval);
  } else if (DataType::BOOLEAN == dest_shape.type) {
    init_val = builder_.getFalse();
  } else {
    // Previously an unhandled element type fell through with a null
    // init_val, crashing inside AggInit; fail with a diagnostic instead.
    throw Error("Invalid agg_init_add type: " + to_string(dest_shape.type));
  }
  AggInit(dest, init_val);
}

void Compiler::AggInitMul(const stripe::Special& agg_init) {
  // One output: a tensor to initialize with the multiplicative identity (one).
  assert(0 == agg_init.inputs.size());
  assert(1 == agg_init.outputs.size());
  Buffer dest = buffers_[agg_init.outputs[0]];
  auto& dest_shape = dest.refinement->interior_shape;
  size_t bits = bit_width(dest_shape.type);
  llvm::Type* eltype = CType(dest_shape.type);
  llvm::Value* init_val = nullptr;
  if (is_float(dest_shape.type)) {
    init_val = llvm::ConstantFP::get(eltype, 1.0);
  } else if (is_int(dest_shape.type) || is_uint(dest_shape.type)) {
    auto apval = llvm::APInt(bits, 1);
    init_val = llvm::ConstantInt::get(eltype, apval);
  } else if (DataType::BOOLEAN == dest_shape.type) {
    init_val = builder_.getTrue();
  } else {
    // Previously an unhandled element type fell through with a null
    // init_val, crashing inside AggInit; fail with a diagnostic instead.
    throw Error("Invalid agg_init_mul type: " + to_string(dest_shape.type));
  }
  AggInit(dest, init_val);
}

void Compiler::AggInitMin(const stripe::Special& agg_init) {
  // Initialize a tensor with the identity for "min": the largest
  // representable value of its element type, so any real element wins.
  assert(agg_init.inputs.empty());
  assert(agg_init.outputs.size() == 1);
  Buffer dest = buffers_[agg_init.outputs[0]];
  const auto& shape = dest.refinement->interior_shape;
  size_t bits = bit_width(shape.type);
  llvm::Type* eltype = CType(shape.type);
  llvm::Value* init_val = nullptr;
  if (is_float(shape.type)) {
    // +inf: every finite float compares smaller.
    init_val = llvm::ConstantFP::getInfinity(eltype, /*Negative=*/false);
  } else if (is_int(shape.type)) {
    init_val = llvm::ConstantInt::get(eltype, llvm::APInt::getSignedMaxValue(bits));
  } else if (is_uint(shape.type)) {
    init_val = llvm::ConstantInt::get(eltype, llvm::APInt::getMaxValue(bits));
  } else if (shape.type == DataType::BOOLEAN) {
    init_val = builder_.getTrue();
  }
  AggInit(dest, init_val);
}

void Compiler::AggInitMax(const stripe::Special& agg_init) {
  // Initialize a tensor with the identity for "max": the smallest
  // representable value of its element type, so any real element wins.
  assert(agg_init.inputs.empty());
  assert(agg_init.outputs.size() == 1);
  Buffer dest = buffers_[agg_init.outputs[0]];
  const auto& shape = dest.refinement->interior_shape;
  size_t bits = bit_width(shape.type);
  llvm::Type* eltype = CType(shape.type);
  llvm::Value* init_val = nullptr;
  if (is_float(shape.type)) {
    // -inf: every finite float compares larger.
    init_val = llvm::ConstantFP::getInfinity(eltype, /*Negative=*/true);
  } else if (is_int(shape.type)) {
    init_val = llvm::ConstantInt::get(eltype, llvm::APInt::getSignedMinValue(bits));
  } else if (is_uint(shape.type)) {
    init_val = llvm::ConstantInt::get(eltype, llvm::APInt::getMinValue(bits));
  } else if (shape.type == DataType::BOOLEAN) {
    init_val = builder_.getFalse();
  }
  AggInit(dest, init_val);
}

void Compiler::AggInit(const Buffer& dest, llvm::Value* init_val) {
  // Store init_val (the aggregation identity chosen by the caller) into
  // every element of the destination tensor, via one counted loop per
  // dimension.
  auto& dest_shape = dest.refinement->interior_shape;
  size_t dest_ndims = dest_shape.dims.size();

  // Validate the initializer before emitting any IR; previously this check
  // ran after the loop nest was generated, so an unsupported element type
  // would throw while leaving a half-built loop nest in the function.
  if (!init_val) {
    throw Error("Undefined agg_op init for " + to_string(dest_shape.type));
  }

  // Generate the limit value for each dimension.
  std::vector<llvm::Value*> limits(dest_ndims);
  for (size_t i = 0; i < dest_ndims; ++i) {
    limits[i] = IndexConst(dest_shape.dims[i].size);
  }

  // Allocate an index variable for each loop.
  std::vector<llvm::Value*> idx_vars(dest_ndims);
  for (size_t i = 0; i < dest_ndims; ++i) {
    std::string name = "idx_" + std::to_string(i);
    idx_vars[i] = builder_.CreateAlloca(IndexType(), nullptr, name);
  }

  // Generate the initialization & limit test for each loop.
  std::vector<Loop> loops(dest_ndims);
  for (size_t i = 0; i < dest_ndims; ++i) {
    std::string name = std::to_string(i);
    CreateLoop(&loops[i], name);
    EnterLoop(&loops[i], idx_vars[i], IndexConst(0), limits[i]);
  }

  // Compute the element offset for these indexes (sum of index * stride).
  llvm::Value* dest_idx = IndexConst(0);
  for (size_t i = 0; i < dest_ndims; ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(idx_vars[i]);
    llvm::Value* stride = IndexConst(dest_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    dest_idx = builder_.CreateAdd(dest_idx, idx_val);
  }
  llvm::Value* dest_element = builder_.CreateGEP(dest.base, dest_idx);
  builder_.CreateStore(init_val, dest_element);

  // Terminate each loop (innermost first) by incrementing its index and
  // branching back to the test.
  for (size_t i = dest_ndims; i-- > 0;) {
    LeaveLoop(&loops[i], idx_vars[i]);
  }
}

void Compiler::ParallelFor(llvm::Value* refs, llvm::Value* idxs, size_t range, llvm::Function* block) {
  // Emit a call to the runtime's "ParallelFor" entry point, passing the
  // refinement-pointer array, the index array, the composite index range,
  // and the block function to invoke. The block callback's signature is
  // (i8** refs, index* idxs, index begin, index end), matching the
  // THREADED_BLOCK shape produced by BlockType.
  // NOTE(review): presumably the runtime splits [0, range) across worker
  // threads — verify against the runtime's ParallelFor implementation.
  llvm::Type* ptrArrayType = builder_.getInt8Ty()->getPointerTo()->getPointerTo();
  llvm::Type* idxArrayType = IndexType()->getPointerTo();
  std::vector<llvm::Type*> blockArgTypes{ptrArrayType, idxArrayType, IndexType(), IndexType()};
  llvm::Type* blockType = llvm::FunctionType::get(builder_.getVoidTy(), blockArgTypes, false);
  llvm::Type* blockPtrType = blockType->getPointerTo();
  std::vector<llvm::Type*> fnArgTypes{ptrArrayType, idxArrayType, IndexType(), blockPtrType};
  auto fnType = llvm::FunctionType::get(builder_.getVoidTy(), fnArgTypes, false);
  auto fn = module_->getOrInsertFunction("ParallelFor", fnType).getCallee();
  std::vector<llvm::Value*> argvals{refs, idxs, IndexConst(range), block};
  builder_.CreateCall(fn, argvals, "");
}

void Compiler::Scatter(const stripe::Special& scatter) {
  // Three inputs: "data", "indices", "shape"; one output.
  // For each value in "data", look up the corresponding location from
  // "indices", then accumulate the value into that location in the output.
  // The "shape" parameter is a one-dimensional array of integers specifying
  // the shape of the output tensor. We don't need this information since the
  // output has already been allocated and we already know its shape.
  assert(3 == scatter.inputs.size());
  Buffer data = buffers_[scatter.inputs[0]];
  auto& data_shape = data.refinement->interior_shape;
  Buffer indices = buffers_[scatter.inputs[1]];
  auto& indices_shape = indices.refinement->interior_shape;
  assert(1 == scatter.outputs.size());
  Buffer output = buffers_[scatter.outputs[0]];
  auto& output_shape = output.refinement->interior_shape;
  assert(output_shape == buffers_[scatter.inputs[2]].refinement->interior_shape);

  // Build a loop nest over each dimension of the data.
  size_t data_ndims = data_shape.dims.size();
  std::vector<llvm::Value*> limits(data_ndims);
  std::vector<llvm::Value*> idx_vars(data_ndims);
  for (size_t i = 0; i < data_ndims; ++i) {
    std::string name = "idx_" + std::to_string(i);
    idx_vars[i] = builder_.CreateAlloca(IndexType(), nullptr, name);
    limits[i] = IndexConst(data_shape.dims[i].size);
  }

  // Generate the initialization & limit test for each loop
  std::vector<Loop> loops(data_ndims);
  for (size_t i = 0; i < data_ndims; ++i) {
    std::string name = std::to_string(i);
    CreateLoop(&loops[i], name);
    EnterLoop(&loops[i], idx_vars[i], IndexConst(0), limits[i]);
  }

  // Body of the loop nest
  // Look up the value of the data element.
  llvm::Value* data_idx = IndexConst(0);
  for (size_t i = 0; i < data_ndims; ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(idx_vars[i]);
    llvm::Value* stride = IndexConst(data_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    data_idx = builder_.CreateAdd(data_idx, idx_val);
  }
  llvm::Value* data_element = builder_.CreateGEP(data.base, data_idx);
  llvm::Value* data_val = builder_.CreateLoad(data_element);

  // Look up the index value corresponding to this position.
  auto idx_var_iter = idx_vars.begin();
  llvm::Value* indirect_idx = IndexConst(0);
  for (size_t i = 0; i < indices_shape.dims.size(); ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(*idx_var_iter++);
    llvm::Value* stride = IndexConst(indices_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    indirect_idx = builder_.CreateAdd(indirect_idx, idx_val);
  }
  llvm::Value* indirect_element = builder_.CreateGEP(indices.base, indirect_idx);
  llvm::Value* indirect_val = builder_.CreateLoad(indirect_element);

  // Clamp the index value to the output range; that is, the index of the
  // last element of the first output dimension. (Clamping to the dimension
  // size itself, as this code previously did, allowed a write one row past
  // the end of the output buffer.)
  bool ind_signed = !is_uint(indices_shape.type);
  auto cast_op = llvm::CastInst::getCastOpcode(indirect_val, ind_signed, IndexType(), false);
  indirect_val = builder_.CreateCast(cast_op, indirect_val, IndexType());
  // Assumes output_shape.dims[0].size >= 1; a zero-sized output would wrap.
  llvm::Value* ind_limit = IndexConst(output_shape.dims[0].size - 1);
  llvm::Value* must_clamp = builder_.CreateICmpUGT(indirect_val, ind_limit);
  indirect_val = builder_.CreateSelect(must_clamp, ind_limit, indirect_val);

  // Using the indirect index as the zero-dimension coordinate and continuing
  // with the remaining idx_vars unused by the indices lookup, locate the
  // output element, then write the data_val. This is not actually a simple
  // store, as one might expect, and as the keras & tensorflow documentation
  // seemingly implies, but an aggregation (add).
  llvm::Value* indirect_stride = IndexConst(output_shape.dims[0].stride);
  llvm::Value* dest_idx = builder_.CreateMul(indirect_val, indirect_stride);
  for (size_t i = 1; i < output_shape.dims.size(); ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(*idx_var_iter++);
    llvm::Value* stride = IndexConst(output_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    dest_idx = builder_.CreateAdd(dest_idx, idx_val);
  }
  llvm::Value* dest_element = builder_.CreateGEP(output.base, dest_idx);
  llvm::Value* old_dest_val = builder_.CreateLoad(dest_element);
  if (is_float(data_shape.type)) {
    data_val = builder_.CreateFAdd(old_dest_val, data_val);
  } else {
    data_val = builder_.CreateAdd(old_dest_val, data_val);
  }
  builder_.CreateStore(data_val, dest_element);

  // Terminate the loop by incrementing the index and repeating.
  for (size_t i = data_ndims; i-- > 0;) {
    LeaveLoop(&loops[i], idx_vars[i]);
  }
}

void Compiler::Gather(const stripe::Special& gather) {
  // Two inputs: "data" and "indices", and one output.
  // For each value in "indices", look up the corresponding element from "data"
  // and write it into the output.
  assert(2 == gather.inputs.size());
  Buffer data = buffers_[gather.inputs[0]];
  auto& data_shape = data.refinement->interior_shape;
  Buffer indices = buffers_[gather.inputs[1]];
  auto& indices_shape = indices.refinement->interior_shape;
  assert(1 == gather.outputs.size());
  Buffer dest = buffers_[gather.outputs[0]];
  auto& dest_shape = dest.refinement->interior_shape;
  // Build a loop nest for each dimension of "indices".
  // Look up the index value, which must be an integer.
  // Clamp the index value to the range of dimension 0 of "data".
  // Build a loop nest for each remaining dimension of "data".
  // Copy the element value from "data" to "dest".
  size_t outer_ndims = indices_shape.dims.size();
  size_t inner_ndims = data_shape.dims.size() - 1;
  size_t dest_ndims = dest_shape.dims.size();
  assert(dest_ndims == outer_ndims + inner_ndims);

  // Compute the limit for each loop based on the tensor shape extents
  std::vector<llvm::Value*> limits(dest_ndims);
  for (size_t i = 0; i < outer_ndims; ++i) {
    limits[i] = IndexConst(indices_shape.dims[i].size);
  }
  for (size_t i = 0; i < inner_ndims; ++i) {
    limits[i + outer_ndims] = IndexConst(data_shape.dims[1 + i].size);
  }

  // Allocate index vars for each loop
  std::vector<llvm::Value*> idx_vars(dest_ndims);
  for (size_t i = 0; i < dest_ndims; ++i) {
    std::string name = "idx_" + std::to_string(i);
    idx_vars[i] = builder_.CreateAlloca(IndexType(), nullptr, name);
  }

  // Generate the initialization & limit test for each loop
  std::vector<Loop> loops(dest_ndims);
  for (size_t i = 0; i < dest_ndims; ++i) {
    std::string name = std::to_string(i);
    CreateLoop(&loops[i], name);
    EnterLoop(&loops[i], idx_vars[i], IndexConst(0), limits[i]);
  }

  // Body of the loop nest: look up an index value from "indexes" using the
  // outer_ndims, then use that index plus the inner_ndims to look up a value
  // from "data"; use all of the dest_ndims to write that into "dest".
  llvm::Value* outer_idx = IndexConst(0);
  for (size_t i = 0; i < outer_ndims; ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(idx_vars[i]);
    llvm::Value* stride = IndexConst(indices_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    outer_idx = builder_.CreateAdd(outer_idx, idx_val);
  }
  llvm::Value* indirect_element = builder_.CreateGEP(indices.base, outer_idx);
  llvm::Value* indirect_val = builder_.CreateLoad(indirect_element);

  // Clamp the indirect val, for safety. Convert to unsigned integer, thereby
  // throwing away negative values, then compare against the index of the
  // last element in the data's zero'th dimension. (Clamping to the dimension
  // size itself, as this code previously did, allowed a read one row past
  // the end of the data buffer.)
  bool ind_signed = !is_uint(indices_shape.type);
  auto cast_op = llvm::CastInst::getCastOpcode(indirect_val, ind_signed, IndexType(), false);
  indirect_val = builder_.CreateCast(cast_op, indirect_val, IndexType());
  // Assumes data_shape.dims[0].size >= 1; a zero-sized input would wrap.
  llvm::Value* ind_limit = IndexConst(data_shape.dims[0].size - 1);
  llvm::Value* must_clamp = builder_.CreateICmpUGT(indirect_val, ind_limit);
  indirect_val = builder_.CreateSelect(must_clamp, ind_limit, indirect_val);

  llvm::Value* indirect_stride = IndexConst(data_shape.dims[0].stride);
  llvm::Value* inner_idx = builder_.CreateMul(indirect_val, indirect_stride);
  for (size_t i = 0; i < inner_ndims; ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(idx_vars[i + outer_ndims]);
    llvm::Value* stride = IndexConst(data_shape.dims[1 + i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    inner_idx = builder_.CreateAdd(inner_idx, idx_val);
  }
  llvm::Value* data_element = builder_.CreateGEP(data.base, inner_idx);
  llvm::Value* data_val = builder_.CreateLoad(data_element);

  llvm::Value* dest_idx = IndexConst(0);
  for (size_t i = 0; i < dest_ndims; ++i) {
    llvm::Value* idx_val = builder_.CreateLoad(idx_vars[i]);
    llvm::Value* stride = IndexConst(dest_shape.dims[i].stride);
    idx_val = builder_.CreateMul(idx_val, stride);
    dest_idx = builder_.CreateAdd(dest_idx, idx_val);
  }
  llvm::Value* dest_element = builder_.CreateGEP(dest.base, dest_idx);
  builder_.CreateStore(data_val, dest_element);

  // Terminate the loop by incrementing the index and repeating.
  for (size_t i = dest_ndims; i-- > 0;) {
    LeaveLoop(&loops[i], idx_vars[i]);
  }
}

void Compiler::CreateLoop(Loop* loop, std::string name) {
  // Create the four basic blocks (init, test, body, done) that make up a
  // counted loop, and branch from the current insertion point into init.
  // EnterLoop and LeaveLoop fill in their contents.
  llvm::Function* func = builder_.GetInsertBlock()->getParent();
  loop->init = llvm::BasicBlock::Create(context_, "init_" + name, func);
  loop->test = llvm::BasicBlock::Create(context_, "test_" + name, func);
  loop->body = llvm::BasicBlock::Create(context_, "body_" + name, func);
  loop->done = llvm::BasicBlock::Create(context_, "done_" + name, func);
  builder_.CreateBr(loop->init);
}

void Compiler::EnterLoop(Loop* loop, llvm::Value* variable, llvm::Value* init, llvm::Value* limit) {
  // Emit the loop preamble: store the initial index, then test it against
  // the limit (unsigned less-than) and branch to body or done. Leaves the
  // builder positioned at the start of the body block.
  builder_.SetInsertPoint(loop->init);
  builder_.CreateStore(init, variable);
  builder_.CreateBr(loop->test);
  builder_.SetInsertPoint(loop->test);
  llvm::Value* idx_val = builder_.CreateLoad(variable);
  llvm::Value* go = builder_.CreateICmpULT(idx_val, limit);
  builder_.CreateCondBr(go, loop->body, loop->done);
  builder_.SetInsertPoint(loop->body);
}

void Compiler::LeaveLoop(Loop* loop, llvm::Value* variable) {
  // Emit the loop postamble: increment the index by one, branch back to the
  // test block, and leave the builder positioned after the loop (done).
  llvm::Value* index = builder_.CreateLoad(variable);
  index = builder_.CreateAdd(index, IndexConst(1));
  builder_.CreateStore(index, variable);
  builder_.CreateBr(loop->test);
  builder_.SetInsertPoint(loop->done);
}

Compiler::Scalar Compiler::Cast(Scalar v, DataType to_type) {
  // Convert a scalar to another data type, letting LLVM pick the correct
  // cast instruction. No-op when the types already match.
  if (v.type == to_type) {
    return v;
  }
  llvm::Type* to_llvmtype = CType(to_type);
  // Floats are treated as "signed" here so getCastOpcode selects the
  // signed int<->float conversions (sitofp/fptosi) rather than unsigned.
  bool from_signed = is_int(v.type) || is_float(v.type);
  bool to_signed = is_int(to_type) || is_float(to_type);
  auto op = llvm::CastInst::getCastOpcode(v.value, from_signed, to_llvmtype, to_signed);
  llvm::Value* ret = builder_.CreateCast(op, v.value, to_llvmtype);
  return Scalar{ret, to_type};
}

Compiler::Scalar Compiler::CheckNotFloat(Scalar v) {
  // Guard for integer-only operations: reject float and invalid scalars,
  // passing anything else through unchanged.
  bool rejected = is_float(v.type) || (DataType::INVALID == v.type);
  if (rejected) {
    throw Error("Expected non-float, actually found " + to_string(v.type));
  }
  return v;
}

llvm::Type* Compiler::CType(DataType type) {
  // Map a Stripe element type onto the corresponding LLVM IR type.
  // Signed and unsigned integers share the same LLVM integer type;
  // signedness is handled at the instruction level.
  // Throws Error for types with no LLVM equivalent.
  switch (type) {
    case DataType::BOOLEAN:
      return builder_.getInt1Ty();
    case DataType::INT8:
    case DataType::UINT8:
      return builder_.getInt8Ty();
    case DataType::INT16:
    case DataType::UINT16:
      return builder_.getInt16Ty();
    case DataType::INT32:
    case DataType::UINT32:
      return builder_.getInt32Ty();
    case DataType::INT64:
    case DataType::UINT64:
      return builder_.getInt64Ty();
    case DataType::FLOAT16:
      return builder_.getHalfTy();
    case DataType::FLOAT32:
      return builder_.getFloatTy();
    case DataType::FLOAT64:
      return builder_.getDoubleTy();
    default:
      throw Error("Invalid type: " + to_string(type));
  }
  // Removed: an unreachable `return builder_.getVoidTy();` followed the
  // switch, but every path either returns or throws.
}

llvm::Value* Compiler::ElementPtr(const Buffer& buf) {
  // Ask the source refinement to generate an access path, in the form of
  // a sequence of indexes to scale and sum. Load each index value, multiply,
  // and the result is an offset from the buffer base address.
  llvm::Value* offset = Eval(buf.refinement->FlatAccess());
  std::vector<llvm::Value*> idxList{offset};
  // Name the GEP after the refinement so the IR is readable when dumped.
  return builder_.CreateGEP(buf.base, idxList, buf.refinement->into() + "[]");
}

llvm::Value* Compiler::Eval(const stripe::Affine& access) {
  // Evaluate an affine expression at runtime by summing its terms: a term
  // with a name contributes (index variable * coefficient); a term with an
  // empty name is a bare constant.
  llvm::Value* offset = IndexConst(0);
  for (auto& term : access.getMap()) {
    llvm::Value* indexVal = nullptr;
    if (!term.first.empty()) {
      // NOTE(review): indexes_[...] default-inserts when the name is absent,
      // which would yield a null `variable` and crash in CreateLoad —
      // presumably every named index was registered by the enclosing block;
      // verify at the call sites.
      llvm::Value* indexVar = indexes_[term.first].variable;
      indexVal = builder_.CreateLoad(indexVar);
      llvm::Value* multiplier = IndexConst(term.second);
      indexVal = builder_.CreateMul(indexVal, multiplier);
    } else {
      indexVal = IndexConst(term.second);
    }
    offset = builder_.CreateAdd(offset, indexVal);
  }
  return offset;
}

void Compiler::OutputType(llvm::Value* ret, const stripe::Intrinsic& intrinsic) {
  // Bind the intrinsic's single output scalar, typed per the intrinsic,
  // and name the IR value after the output for readability.
  assert(intrinsic.outputs.size() == 1);
  const auto& out_name = intrinsic.outputs[0];
  scalars_[out_name] = Scalar{ret, intrinsic.type};
  ret->setName(out_name);
}

void Compiler::OutputBool(llvm::Value* ret, const stripe::Intrinsic& intrinsic) {
  // Bind the intrinsic's single output scalar as BOOLEAN (for comparisons
  // and predicates), and name the IR value after the output.
  assert(intrinsic.outputs.size() == 1);
  const auto& out_name = intrinsic.outputs[0];
  scalars_[out_name] = Scalar{ret, DataType::BOOLEAN};
  ret->setName(out_name);
}

void Compiler::CallIntrinsicFunc(const stripe::Intrinsic& stmt, const char* name_f32, const char* name_f64,
                                 const size_t numParamsIn) {
  // Emit a call to a one- or two-argument C library routine, choosing the
  // f32 or f64 flavor by the statement's type, and bind the result as the
  // intrinsic's output scalar.
  const size_t numParams = stmt.inputs.size();
  if (numParamsIn == 1) {
    if (numParams != 1) {
      throw std::runtime_error("CallIntrinsicFunction expects 1 parameter");
    }
  } else if (numParamsIn == 2) {
    if (numParams != 2) {
      throw std::runtime_error("CallIntrinsicFunction expects 2 parameters");
    }
  } else {
    throw std::runtime_error("CallIntrinsicFunction expects 1 or 2 parameters");
  }

  // Cast each operand to the statement's type before the call.
  std::vector<llvm::Value*> argvals;
  argvals.push_back(Cast(scalars_[stmt.inputs[0]], stmt.type).value);
  if (numParamsIn == 2) {
    argvals.push_back(Cast(scalars_[stmt.inputs[1]], stmt.type).value);
  }

  // C intrinsics come in either f32 or f64 flavors. We'll use f32 for single
  // and half-precision float inputs, f64 for ints and doubles.
  const bool use_f32 = (stmt.type == DataType::FLOAT16 || stmt.type == DataType::FLOAT32);
  const char* name = use_f32 ? name_f32 : name_f64;
  llvm::Type* ctype = use_f32 ? builder_.getFloatTy() : builder_.getDoubleTy();
  std::vector<llvm::Type*> argtypes(numParamsIn, ctype);
  auto functype = llvm::FunctionType::get(ctype, argtypes, false);
  auto func = module_->getOrInsertFunction(name, functype).getCallee();
  llvm::Value* ret = builder_.CreateCall(func, argvals, "");
  OutputType(ret, stmt);
}

llvm::Type* Compiler::IndexType() {
  // Index arithmetic uses an integer as wide as a pointer on the target.
  const unsigned bits = module_->getDataLayout().getPointerSizeInBits();
  return llvm::IntegerType::get(context_, bits);
}

llvm::Value* Compiler::IndexConst(ssize_t val) {
  // Materialize val as a pointer-sized integer constant.
  return llvm::ConstantInt::get(IndexType(), val);
}

llvm::FunctionType* Compiler::BlockType(const stripe::Block& block) {
  // Generate a type for the function which will implement this block.
  // Directly-invoked blocks take one pointer per refinement plus one index
  // init value per block index; threaded blocks take the fixed four-argument
  // ParallelFor callback signature instead.
  std::vector<llvm::Type*> param_types;
  if (getCompileFor(block) != THREADED_BLOCK) {
    // This block function will be invoked directly.
    // Each buffer base address will be provided as a parameter.
    for (const auto& ref : block.refs) {
      param_types.push_back(CType(ref.interior_shape.type)->getPointerTo());
    }
    // Following the buffers, a parameter will provide the initial value for
    // each of the block's indexes.
    for (size_t i = 0; i < block.idxs.size(); ++i) {
      param_types.push_back(IndexType());
    }
  } else {
    // This block function will be executed via ParallelFor.
    // First parameter is a pointer to an array of refinement base addresses.
    // Since all block functions must have the same type signature, we will
    // define this as int8_t** instead and bitcast whenever we use it.
    auto int8PtrType = builder_.getInt8Ty()->getPointerTo();
    param_types.push_back(int8PtrType->getPointerTo());
    // Next, pointer to an array of index init values.
    param_types.push_back(IndexType()->getPointerTo());
    // Third and fourth, composite index range begin and end values.
    param_types.push_back(IndexType());
    param_types.push_back(IndexType());
  }
  // Blocks never return a value.
  llvm::Type* return_type = builder_.getVoidTy();
  return llvm::FunctionType::get(return_type, param_types, false);
}

llvm::Value* Compiler::XSMMDispatchFunction(llvm::Type* alphaPtrType, llvm::Type* betaPtrType, llvm::Type* aPtrType,
                                            llvm::Type* bPtrType, llvm::Type* cPtrType,
                                            const std::string& functionName) {
  // Declare (or look up) an XSMM GEMM dispatch function. The dispatcher
  // takes the GEMM problem description (m, n, k, leading dimensions, alpha,
  // beta, flags, prefetch) and returns a pointer to a specialized kernel
  // taking (a, b, c). Returning a function pointer from the call is why the
  // return type is rftype->getPointerTo().
  llvm::Type* iptr = llvm::Type::getInt32PtrTy(context_);
  std::vector<llvm::Type*> param_types{
      aPtrType,  // a
      bPtrType,  // b
      cPtrType,  // c
  };
  llvm::FunctionType* rftype = llvm::FunctionType::get(builder_.getVoidTy(), param_types, false);
  std::vector<llvm::Type*> argtypes{
      builder_.getInt32Ty(),  // m
      builder_.getInt32Ty(),  // n
      builder_.getInt32Ty(),  // k
      iptr,                   // lda
      iptr,                   // ldb
      iptr,                   // ldc
      alphaPtrType,           // alpha
      betaPtrType,            // beta
      iptr,                   // flags
      iptr,                   // prefetch
  };
  auto functype = llvm::FunctionType::get(rftype->getPointerTo(), argtypes, false);
  return module_->getOrInsertFunction(functionName.c_str(), functype).getCallee();
}

llvm::Value* Compiler::Malloc(size_t size) {
  // Emit a call to the platform's aligned heap allocator, returning the
  // allocated buffer as an i8*.
  std::vector<llvm::Type*> argtypes{IndexType()
  // MacOS RT doesn't have the align_alloc function and the allocations
  // on it are 16 bytes aligned.
#ifndef __APPLE__
                                        ,
                                    IndexType()
#endif  // __APPLE__
  };
  llvm::Type* rettype = builder_.getInt8PtrTy();
  auto functype = llvm::FunctionType::get(rettype, argtypes, false);
#ifdef __APPLE__
  const char* funcname = "malloc";
#elif defined(_WIN32)  // !__APPLE__
  // The MSVC CRT routine is "_aligned_malloc" (single leading underscore),
  // matching the "_aligned_free" used in Compiler::Free; the previous
  // "__aligned_malloc" spelling would fail to resolve at run time.
  const char* funcname = "_aligned_malloc";
#else                  // !__APPLE__ && !_WIN32
  const char* funcname = "aligned_alloc";
#endif                 // __APPLE__
  auto func = module_->getOrInsertFunction(funcname, functype).getCallee();
  // Argument order differs by platform: _aligned_malloc(size, alignment),
  // aligned_alloc(alignment, size), malloc(size).
  auto buffer = builder_.CreateCall(func,
                                    {
#ifdef _WIN32
                                        IndexConst(size), IndexConst(VECTOR_MEM_ALIGNMENT)},
#else  // !_WIN32
#ifndef __APPLE__
                                        IndexConst(VECTOR_MEM_ALIGNMENT),
#endif  // __APPLE__
                                        IndexConst(size)},
#endif  // _WIN32
                                    "");
  return buffer;
}

llvm::Value* Compiler::RunTimeLogEntry(void) {
  // Declare (or look up) the runtime logging hook, taking two C strings
  // and a float value, returning nothing.
  llvm::Type* charPtrType = builder_.getInt8Ty()->getPointerTo();
  std::vector<llvm::Type*> argtypes{
      charPtrType,
      charPtrType,
      builder_.getFloatTy(),
  };
  auto functype = llvm::FunctionType::get(llvm::Type::getVoidTy(context_), argtypes, false);
  return module_->getOrInsertFunction("RunTimeLogEntry", functype).getCallee();
}

void Compiler::EmitRunTimeLogEntry(const std::string& str, const std::string& extra, llvm::Value* value) {
  std::vector<llvm::Value*> log_args{
      builder_.CreateGlobalStringPtr(str),
      builder_.CreateGlobalStringPtr(extra),
  };
  if (value) {
    log_args.push_back(value);
  } else {
    log_args.push_back(llvm::ConstantFP::get(builder_.getFloatTy(), 0.0));
  }
  builder_.CreateCall(RunTimeLogEntry(), log_args, "");
}

void Compiler::Free(llvm::Value* buffer) {
  // Emit a call releasing a buffer obtained from Malloc, using the
  // deallocator that matches the platform's allocator.
#ifdef _WIN32
  const char* funcname = "_aligned_free";
#else   // !_WIN32
  const char* funcname = "free";
#endif  // _WIN32
  std::vector<llvm::Type*> argtypes{builder_.getInt8PtrTy()};
  auto functype = llvm::FunctionType::get(llvm::Type::getVoidTy(context_), argtypes, false);
  auto func = module_->getOrInsertFunction(funcname, functype).getCallee();
  builder_.CreateCall(func, {buffer}, "");
}

llvm::Value* Compiler::PrngStepFunction(void) {
  // Declare (or look up) the runtime "prng_step" routine, which takes two
  // i32 pointers, a float pointer, and an index-sized count.
  std::vector<llvm::Type*> argtypes{
      builder_.getInt32Ty()->getPointerTo(),
      builder_.getInt32Ty()->getPointerTo(),
      builder_.getFloatTy()->getPointerTo(),
      IndexType(),
  };
  auto functype = llvm::FunctionType::get(llvm::Type::getVoidTy(context_), argtypes, false);
  return module_->getOrInsertFunction("prng_step", functype).getCallee();
}

llvm::Value* Compiler::ReadCycleCounter(void) {
  // Read the CPU cycle counter via the llvm.readcyclecounter intrinsic,
  // divided by 1000 to keep the profiling accumulators in smaller units.
  // NOTE(review): the divisor is IndexType-sized, which matches the
  // intrinsic's i64 result only on 64-bit targets — verify on 32-bit.
  auto functype = llvm::FunctionType::get(builder_.getInt64Ty(), {}, false);
  const char* funcname = "llvm.readcyclecounter";
  auto func = module_->getOrInsertFunction(funcname, functype).getCallee();
  auto divisor = IndexConst(1000);
  auto rcc = builder_.CreateCall(func, {}, "");
  auto div = builder_.CreateUDiv(rcc, divisor);
  return div;
}

void Compiler::ProfileBlockEnter(const stripe::Block& block) {
  // Instrument the start of a block: bump its execution counter and begin
  // its elapsed-tick measurement. Counters are per-block module globals
  // named by ProfileBlockID. No-op unless profiling is enabled.
  if (!config_.profile_block_execution) {
    return;
  }
  // allocate counter variable
  std::string block_id = ProfileBlockID(block);
  std::string profile_count_name = profile_count_name_ + block_id;
  module_->getOrInsertGlobal(profile_count_name, IndexType());
  auto profile_count_gval = module_->getNamedGlobal(profile_count_name);
  profile_count_gval->setInitializer(llvm::Constant::getNullValue(IndexType()));
  // increment the execution count for this pass through the block
  builder_.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, profile_count_gval, IndexConst(1),
                           llvm::AtomicOrdering::Monotonic);
  // allocate timing variable
  std::string profile_ticks_name = profile_ticks_name_ + block_id;
  module_->getOrInsertGlobal(profile_ticks_name, IndexType());
  auto profile_ticks_gval = module_->getNamedGlobal(profile_ticks_name);
  profile_ticks_gval->setInitializer(llvm::Constant::getNullValue(IndexType()));
  // Subtract the current rdtsc value from the saved tick count. This will
  // give us a temporarily invalid base value, which we will correct by adding
  // the ending rdtsc value back in when the block finishes.
  builder_.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Sub, profile_ticks_gval, ReadCycleCounter(),
                           llvm::AtomicOrdering::Monotonic);
}

void Compiler::ProfileBlockLeave(const stripe::Block& block) {
  // Instrument the end of a block; pairs with ProfileBlockEnter, which
  // created the globals looked up here. No-op unless profiling is enabled.
  if (!config_.profile_block_execution) {
    return;
  }
  // Add the current rdtsc back into the elapsed time counter, which both
  // corrects for the bias we introduced on function entry and accumulates
  // the elapsed time into the running total.
  std::string block_id = ProfileBlockID(block);
  std::string profile_ticks_name = profile_ticks_name_ + block_id;
  auto profile_ticks_gval = module_->getNamedGlobal(profile_ticks_name);
  builder_.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, profile_ticks_gval, ReadCycleCounter(),
                           llvm::AtomicOrdering::Monotonic);
}

void Compiler::ProfileLoopEnter(const stripe::Block& block) {
  // Instrument the start of a block's loop body, using the same
  // subtract-now/add-later trick as ProfileBlockEnter. No-op unless
  // loop-body profiling is enabled.
  if (!config_.profile_loop_body) {
    return;
  }
  // Count the time spent in the loop body, excluding increment/condition
  // overhead
  std::string block_id = ProfileBlockID(block);
  std::string profile_name = profile_loop_body_name_ + block_id;
  module_->getOrInsertGlobal(profile_name, IndexType());
  auto profile_gval = module_->getNamedGlobal(profile_name);
  profile_gval->setInitializer(llvm::Constant::getNullValue(IndexType()));
  builder_.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Sub, profile_gval, ReadCycleCounter(),
                           llvm::AtomicOrdering::Monotonic);
}

void Compiler::ProfileLoopLeave(const stripe::Block& block) {
  // Instrument the end of a block's loop body; pairs with ProfileLoopEnter,
  // which created the global looked up here. Adding the current counter
  // value corrects the entry-time bias and accumulates the elapsed ticks.
  if (!config_.profile_loop_body) {
    return;
  }
  std::string block_id = ProfileBlockID(block);
  std::string profile_name = profile_loop_body_name_ + block_id;
  auto profile_gval = module_->getNamedGlobal(profile_name);
  builder_.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, profile_gval, ReadCycleCounter(),
                           llvm::AtomicOrdering::Monotonic);
}

std::string Compiler::ProfileBlockID(const stripe::Block& block) {
  // Combine the block's name with its address, so distinct blocks that
  // happen to share a name still get distinct profiling symbols.
  auto addr = reinterpret_cast<uintptr_t>(&block);
  return block.name + "@" + std::to_string(addr);
}

void Compiler::PrintOutputAssembly(llvm::TargetMachine* machine) {
  // Debug aid: run the target's codegen pass pipeline to emit the module's
  // assembly into a string, then dump it to stderr. The inner scope flushes
  // the stream buffers before the string is printed.
  // NOTE(review): addPassesToEmitFile returns true on failure; the result
  // is unchecked here — an unsupported target would print nothing.
  std::string outputStr;
  {
    llvm::raw_string_ostream stream(outputStr);
    llvm::buffer_ostream pstream(stream);
    llvm::legacy::PassManager pm;
    machine->addPassesToEmitFile(pm, pstream, nullptr, llvm::CGFT_AssemblyFile);
    pm.run(*module_);
  }
  llvm::errs() << outputStr << "\n";
}

CompileFor Compiler::getCompileFor(const stripe::Block& block) {
  // Choose the codegen strategy for a block based on its tags.
  if (block.has_tag("xsmm")) {
    // xsmm and cpu_threads tags are incompatible.
    assert(!block.has_tag("cpu_thread"));
    return XSMM_BLOCK;
  }
  // Thread only when there is more than one iteration to distribute.
  bool threaded = block.has_tag("cpu_thread") && block.idxs_product() > 1;
  return threaded ? THREADED_BLOCK : NORMAL_BLOCK;
}

}  // namespace cpu
}  // namespace targets
}  // namespace tile
}  // namespace vertexai
